diff --git a/.gitignore b/.gitignore index 8c6455c1b..b37d0b1b5 100644 --- a/.gitignore +++ b/.gitignore @@ -67,6 +67,7 @@ xformers/csrc/attention/hip_fmha/*.hip xformers/csrc/attention/hip_fmha/*_hip.h xformers/csrc/attention/hip_fmha/instances/*.cu xformers/csrc/attention/hip_fmha/instances/*.hip -xformers/csrc/attention/hip_fmha/instances_tiled/*.cu -xformers/csrc/attention/hip_fmha/instances_tiled/*.hip +xformers/csrc/attention/hip_fmha/instances/*.cu +xformers/csrc/attention/hip_fmha/instances/*.hip +xformers/csrc/attention/hip_fmha/instances/*_hip.h diff --git a/setup.py b/setup.py index d74a53fba..f2f037f5b 100644 --- a/setup.py +++ b/setup.py @@ -418,6 +418,14 @@ def get_extensions(): "--ptxas-options=-allow-expensive-optimizations=true", ] elif torch.cuda.is_available() and torch.version.hip: + disable_hd256_hip_fmha = os.getenv("DISABLE_HD256_HIP_FMHA", "0") + if disable_hd256_hip_fmha == "1": + source_hip_maxk_256 = [] + for ff in source_hip: + if ff.endswith("maxk_256.cpp"): + source_hip_maxk_256 += [ff] + source_hip = list(set(source_hip) - set(source_hip_maxk_256)) + rename_cpp_cu(source_hip) rocm_home = os.getenv("ROCM_PATH") hip_version = get_hip_version(rocm_home) @@ -436,9 +444,16 @@ def get_extensions(): Path(this_dir) / "third_party" / "composable_kernel_tiled" / "include" ] + use_rtn_bf16_convert = os.getenv("ENABLE_HIP_FMHA_RTN_BF16_CONVERT", "0") + generator_flag = [] + if disable_hd256_hip_fmha == "1": + generator_flag += ["-DFMHA_SUPPORT_MAX_HEADDIM_128=1"] cc_flag = ["-DBUILD_PYTHON_PACKAGE"] + if use_rtn_bf16_convert == "1": + cc_flag += ["-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=0"] + extra_compile_args = { "cxx": ["-O3", "-std=c++17"] + generator_flag, "nvcc": [ @@ -447,10 +462,12 @@ def get_extensions(): f"--offload-arch={os.getenv('HIP_ARCHITECTURES', 'native')}", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", - "-DCK_FMHA_FWD_FAST_EXP2=1", + "-DCK_TILE_FMHA_FWD_FAST_EXP2=1", "-fgpu-flush-denormals-to-zero", "-Werror", "-Woverloaded-virtual", + "-mllvm", + "-enable-post-misched=0", ] + generator_flag + cc_flag, diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py index 3f9e5d9d9..e1103c40a 100644 --- a/tests/test_mem_eff_attention.py +++ b/tests/test_mem_eff_attention.py @@ -38,13 +38,16 @@ if torch.cuda.is_available(): compute_capability = torch.cuda.get_device_capability("cuda") sm70_or_better_only = pytest.mark.skipif( - compute_capability < (7, 0), reason="requires sm70+" + torch.version.cuda is not None and compute_capability < (7, 0), + reason="requires sm70+", ) sm75_or_better_only = pytest.mark.skipif( - compute_capability < (7, 5), reason="requires sm75+" + torch.version.cuda is not None and compute_capability < (7, 5), + reason="requires sm75+", ) sm80_or_better_only = pytest.mark.skipif( - compute_capability < (8, 0), reason="requires sm80+" + torch.version.cuda is not None and compute_capability < (8, 0), + reason="requires sm80+", ) sm90_or_better_only = pytest.mark.skipif( compute_capability < (9, 0), reason="requires sm90+" @@ -670,16 +673,8 @@ def test_backward( if op_bw == fmha.ck.BwOp: op_fw = fmha.ck.FwOp - if dtype == torch.bfloat16: - pytest.skip( - "CK Fmha backward for bfloat16 currently is not very accurate for some cases!" - ) if grad_out_contiguous is False: pytest.skip("CK Fmha does not support contiguous layout for grad_out!") - if k % 2 != 0: - pytest.skip( - "CK Fmha currently requires the headdim size of query input be an even value!" 
- ) qkv = None @@ -1586,7 +1581,7 @@ def test_decoder( # kv_heads = 1: multiquery # kv_heads = None: neither MQA nor GQA # kv_heads > 1: BMGHK - if dtype == "bf16" and compute_capability < (8, 0): + if dtype == "bf16" and torch.version.cuda and compute_capability < (8, 0): raise pytest.skip("BF16 is only supported on SM80+") import triton diff --git a/third_party/composable_kernel_tiled b/third_party/composable_kernel_tiled index e3f44659c..c8b6b6424 160000 --- a/third_party/composable_kernel_tiled +++ b/third_party/composable_kernel_tiled @@ -1 +1 @@ -Subproject commit e3f44659cf77df8c3de15eb14baffd58be6ac550 +Subproject commit c8b6b64240e840a7decf76dfaa13c37da5294c4a diff --git a/xformers/csrc/attention/hip_fmha/GENERATE_INSTANCES.md b/xformers/csrc/attention/hip_fmha/GENERATE_INSTANCES.md new file mode 100644 index 000000000..829df6646 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/GENERATE_INSTANCES.md @@ -0,0 +1,33 @@ + +# Instances generator + + The instances generator is a simple Python tool used to generate several hundred instances (.cpp files) and their references (.h files). + Without this tool, writing those instances and references by hand would be very laborious and error-prone. + + The instances generated by this script are divided into three categories, visible in the script: + * Infer -- instances for calling inference-only kernels + * Forward -- instances for calling training forward kernels + * Backward -- instances for calling training backward kernels + + The instance generator is intended for the HIP fmha developers themselves. It is not meant to be used by xformers users for + building xformers, since for xformers users the instances are already provided as part of the xformers code. A sketch of what one generated instance looks like follows.
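For orientation, each generated instance file holds one explicit instantiation of a dispatch template, and the matching `*_instances_ref.h` header collects the corresponding `extern template` declarations for the dispatcher translation units to include. A hypothetical sketch follows — the exact include set and parameter tuples are chosen by `generate_instances.py`; only the function name and `(BatchedBackwardParams&, hipStream_t)` signature are taken from the sources in this diff:

```
// Hypothetical generated instance (.cpp) file; generate_instances.py emits
// one such explicit instantiation per supported parameter combination.
#include <ck_tile/core.hpp>
#include "ck_tiled_fmha_batched_backward.h"

template void run_batched_backward_causalmask_bias_dropout_dispatch<
    ck_tile::bf16_t, // data type of q/k/v
    true,            // has causal mask
    false,           // has attention bias
    false,           // bias requires grad
    false,           // has dropout
    128>             // MaxK: the max head dimension this instance compiles
    (BatchedBackwardParams& param, hipStream_t stream);
```

The matching reference header would carry the same list prefixed with `extern template`, which is why the hunks below can replace pages of inline declarations with a single `#include "instances/..._instances_ref.h"`.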
+ +## How to use the instance generator + + * To generate the complete set of instances supported by the current implementation: + + ``` + #> python xformers/csrc/attention/hip_fmha/generate_instances.py + ``` + * To generate a reduced set of instances (when headdim 256 is not required): + + ``` + #> python xformers/csrc/attention/hip_fmha/generate_instances.py --ignore-hd256 + ``` + * More options besides `--ignore-hd256` may be added to support further customization of instance generation as required. + +## Where the instance files are located + The instance files and reference files are always located in a folder `instances/` under the same directory + as the file `generate_instances.py` itself + +
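The next hunk, in attention_backward_generic_ck_tiled.cpp, drops the even-headdim restriction and introduces fp32 staging for dQ: for fp16/bf16 inputs, the query gradient is accumulated into a zero-filled float buffer (grad_q_f32) and later cast back to the storage dtype by a dedicated convert-dQ kernel — which pairs with the removal of the bf16-accuracy skip in test_mem_eff_attention.py above. A minimal sketch of the allocation decision, assuming the ATen API (`maybe_make_grad_q_f32` is a hypothetical helper, not a function from this patch):

```
// Sketch of the fp32 dQ staging decision made in
// efficient_attention_backward_ck (hypothetical helper).
#include <ATen/ATen.h>

at::Tensor maybe_make_grad_q_f32(const at::Tensor& query, at::Tensor& grad_q) {
  const bool use_grad_q_f32 =
      query.scalar_type() == at::ScalarType::Half ||
      query.scalar_type() == at::ScalarType::BFloat16;
  if (!use_grad_q_f32) {
    grad_q.fill_(0); // fp32 inputs: accumulate dQ directly into grad_q
    return at::Tensor(); // no staging buffer needed
  }
  // fp16/bf16 inputs: zero-filled fp32 scratch; dQ is accumulated across
  // key/value tiles here, then converted back to grad_q's dtype
  return at::zeros(grad_q.sizes(), query.options().dtype(at::kFloat));
}
```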
diff --git a/xformers/csrc/attention/hip_fmha/attention_backward_generic_ck_tiled.cpp b/xformers/csrc/attention/hip_fmha/attention_backward_generic_ck_tiled.cpp index c9494060b..b470f5990 100644 --- a/xformers/csrc/attention/hip_fmha/attention_backward_generic_ck_tiled.cpp +++ b/xformers/csrc/attention/hip_fmha/attention_backward_generic_ck_tiled.cpp @@ -122,10 +122,6 @@ efficient_attention_backward_ck( int64_t K = query.size(3); int64_t Kv = value.size(3); - if (K % 2 != 0) - throw std::runtime_error( - "Currently CK Fmha requires the headdim of query/key be an even value!"); - auto opts = query.options(); at::Tensor grad_q, grad_k, grad_v, grad_bias; @@ -143,7 +139,6 @@ efficient_attention_backward_ck( grad_q = chunk.select(2, 0); grad_k = chunk.select(2, 1); grad_v = chunk.select(2, 2); - grad_q.fill_(0); } else if ( key.size(3) == value.size(3) && key.storage().is_alias_of(value.storage())) { @@ -157,14 +152,24 @@ efficient_attention_backward_ck( grad_v = chunk.select(2, 1); grad_q = at::empty_strided(query.sizes(), query.strides(), query.options()); - grad_q.fill_(0); } else { grad_q = at::empty_strided(query.sizes(), query.strides(), query.options()); grad_k = at::empty_strided(key.sizes(), key.strides(), key.options()); grad_v = at::empty_strided(value.sizes(), value.strides(), value.options()); - grad_q.fill_(0); } + at::Tensor grad_q_f32; + const bool use_grad_q_f32 = + (query.scalar_type() == at::ScalarType::BFloat16 || + query.scalar_type() == at::ScalarType::Half); + + if (use_grad_q_f32) { + grad_q_f32 = at::empty(grad_q.sizes(), opts.dtype(at::kFloat)); + grad_q_f32.fill_(0); + } else { + grad_q.fill_(0); + }; + // CK-FlashAttn requires q/k/v to have same shapes with dQ/dK/dV respectively TORCH_CHECK(query.sizes() == grad_q.sizes()); TORCH_CHECK(query.strides() == grad_q.strides()); @@ -211,7 +216,7 @@ efficient_attention_backward_ck( TORCH_CHECK(p.B == logsumexp.size(0)); TORCH_CHECK(p.Hq == logsumexp.size(1)); - TORCH_CHECK(p.M <= logsumexp.size(2)); + TORCH_CHECK(p.M == logsumexp.size(2)); if (scale.has_value()) { p.scale = float(*scale); @@ -229,6 +234,11 @@ efficient_attention_backward_ck( p.grad_k_ptr = is_mqa_gqa ? tmp_grad_k.data_ptr() : grad_k.data_ptr(); p.grad_v_ptr = is_mqa_gqa ?
tmp_grad_v.data_ptr() : grad_v.data_ptr(); + if (use_grad_q_f32) + p.grad_q_f32_ptr = grad_q_f32.data_ptr(); + else + p.grad_q_f32_ptr = nullptr; + p.q_strides = { static_cast(query.stride(0)), static_cast(query.stride(1)), @@ -260,6 +270,14 @@ efficient_attention_backward_ck( static_cast(logsumexp.stride(1)), static_cast(logsumexp.stride(2))}; + if (use_grad_q_f32) { + p.grad_q_f32_strides = { + static_cast(grad_q_f32.stride(0)), + static_cast(grad_q_f32.stride(1)), + static_cast(grad_q_f32.stride(2)), + static_cast(grad_q_f32.stride(3))}; + } + if (is_mqa_gqa) { p.grad_k_strides = { static_cast(tmp_grad_k.stride(0)), @@ -335,9 +353,9 @@ efficient_attention_backward_ck( p.max_seqlen_q = *max_seqlen_q_; p.max_seqlen_k = *max_seqlen_k_; - TORCH_CHECK(p.num_batches == logsumexp.size(0)); + // unpadded lse layout required TORCH_CHECK(p.Hq == logsumexp.size(1)); - TORCH_CHECK(p.max_seqlen_q <= logsumexp.size(2)); + TORCH_CHECK(p.M == logsumexp.size(2)); if (scale.has_value()) p.scale = float(*scale); @@ -366,10 +384,16 @@ efficient_attention_backward_ck( static_cast(grad_out.stride(3))}; p.lsed_strides = { - static_cast(logsumexp.stride(0)), static_cast(logsumexp.stride(1)), static_cast(logsumexp.stride(2))}; + if (use_grad_q_f32) { + p.grad_q_f32_strides = { + static_cast(grad_q_f32.stride(1)), + static_cast(grad_q_f32.stride(2)), + static_cast(grad_q_f32.stride(3))}; + } + if (is_mqa_gqa) { p.grad_k_strides = { static_cast(tmp_grad_k.stride(1)), @@ -480,6 +504,11 @@ efficient_attention_backward_ck( p.grad_q_ptr = is_mqa_gqa ? tmp_grad_k.data_ptr() : grad_k.data_ptr(); p.grad_v_ptr = is_mqa_gqa ? tmp_grad_v.data_ptr() : grad_v.data_ptr(); p.grad_bias_ptr = bias_requires_grad ? grad_bias.data_ptr() : nullptr; + + if (use_grad_q_f32) + p.grad_q_f32_ptr = grad_q_f32.data_ptr(); + else + p.grad_q_f32_ptr = nullptr; }; auto inDataType = query.scalar_type(); diff --git a/xformers/csrc/attention/hip_fmha/attention_forward_decoder.cpp b/xformers/csrc/attention/hip_fmha/attention_forward_decoder.cpp index 0cabf3f95..7f126dd33 100644 --- a/xformers/csrc/attention/hip_fmha/attention_forward_decoder.cpp +++ b/xformers/csrc/attention/hip_fmha/attention_forward_decoder.cpp @@ -57,7 +57,7 @@ template < int32_t ThreadsPerWavefront, int32_t WavefrontsPerBlock, int32_t KV_M_MAX = 8192, - int32_t K_MAX = 256> + int32_t K_MAX = K_MAX> at::Tensor& efficient_attention_forward_decoder_ck_out_impl( const at::Tensor& XQ, // [B, 1, G, H, D] const at::Tensor& cache_K, // [B, KV_M_MAX, G, H or 1, D] diff --git a/xformers/csrc/attention/hip_fmha/attention_forward_generic_ck_tiled.cpp b/xformers/csrc/attention/hip_fmha/attention_forward_generic_ck_tiled.cpp index fb29c7d21..4bbfe71ad 100644 --- a/xformers/csrc/attention/hip_fmha/attention_forward_generic_ck_tiled.cpp +++ b/xformers/csrc/attention/hip_fmha/attention_forward_generic_ck_tiled.cpp @@ -316,18 +316,14 @@ efficient_attention_forward_ck( p.dropout_prob = 0.0f; if (p.compute_logsumexp) { - // align the access of logsumexp by each thread-group in cache-line size - int aligned_seqlen_q = (p.max_seqlen_q + 15) / 16 * 16; - logsumexp = at::empty( - {p.num_batches, Hq, aligned_seqlen_q}, opts.dtype(at::kFloat)); + logsumexp = at::empty({1, Hq, M}, opts.dtype(at::kFloat)); p.logsumexp_ptr = logsumexp.data_ptr(); p.lse_strides = { - static_cast(logsumexp.stride(0)), static_cast(logsumexp.stride(1)), static_cast(logsumexp.stride(2))}; } else { p.logsumexp_ptr = nullptr; - p.lse_strides = {0, 0, 0}; + p.lse_strides = {0, 0}; } }; diff --git
a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward.h index 4a535aa5a..8bcb29bee 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward.h +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward.h @@ -23,6 +23,9 @@ template < bool kHasDropout, ck_tile::index_t MaxK> struct batched_backward_causalmask_bias_dropout_dispatch { + using FmhaBlockDropout = + typename FmhaBwdBlockDropoutMaker::dropout; + template using FmhaBwdPipelineProblemTemp = ck_tile::BlockFmhaBwdPipelineProblem< typename FmhaBwdTypeConfig::QDataType, @@ -42,12 +45,18 @@ struct batched_backward_causalmask_bias_dropout_dispatch { typename FmhaBwdTypeConfig::BiasGradDataType, FmhaBwdShape, false, // kIsGroupMode + false, // kIsDeterministic FmhaMask, + FmhaBlockDropout, FmhaTraits>; + static constexpr bool NeedConvertGradQ = !std::is_same< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType>::value; + static void Run(BatchedBackwardParams& param, hipStream_t stream) { { - constexpr ck_tile::index_t kBlockSize = 256; + constexpr ck_tile::index_t kBlockSize = 64; const bool pad_seqlen_q = !(param.M % kBlockSize == 0); const bool pad_headdim_v = @@ -76,9 +85,8 @@ struct batched_backward_causalmask_bias_dropout_dispatch { typename ck_tile::BlockFmhaBwdOGradDotO< FmhaBwdOGradDotOPipelineProblem>; - using FmhaBwdOGradDotOKernel_ = ck_tile::FmhaBwdOGradDotOKernel< - ck_tile::FmhaBwdOGradDotOTilePartitioner, - FmhaBwdOGradDotOPipeline>; + using FmhaBwdOGradDotOKernel_ = + ck_tile::FmhaBwdOGradDotOKernel; RunWithBwdOGradDotOKernel(param, stream); }); @@ -93,10 +101,6 @@ struct batched_backward_causalmask_bias_dropout_dispatch { using FmhaMask = ck_tile::SimplifiedGenericAttentionMask; - using FmhaBwdShape_ = FmhaBwdShape; - using FmhaBwdTilePartitioner_ = - ck_tile::FmhaBwdTilePartitioner; - constexpr auto kBiasEnum = kHasBias ? 
ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS : ck_tile::BlockAttentionBiasEnum::NO_BIAS; @@ -104,8 +108,10 @@ struct batched_backward_causalmask_bias_dropout_dispatch { constexpr bool kPadSeqLenQ = true; constexpr bool kPadSeqLenK = true; - const bool pad_headdim_q = !(param.K % FmhaBwdShape_::kQKHeaddim == 0); - const bool pad_headdim_v = !(param.Kv % FmhaBwdShape_::kVHeaddim == 0); + const bool pad_headdim_q = + !(param.K % FmhaBwdShape::kQKHeaddim == 0); + const bool pad_headdim_v = + !(param.Kv % FmhaBwdShape::kVHeaddim == 0); // usually headdim_q and headdim_v are same, consider them together // to determine whether to do padding saving some compiling time @@ -120,7 +126,7 @@ struct batched_backward_causalmask_bias_dropout_dispatch { kBiasEnum, kHasBiasGrad, false, // kStoreLSE - kHasDropout, + false, // place-holder for kHasDropout, not used actually false, // kDoFp8StaticQuant place-holder occupancy>; @@ -128,7 +134,8 @@ struct batched_backward_causalmask_bias_dropout_dispatch { FmhaBwdPipelineProblemTemp; constexpr auto FmhaBwdPipelineEnum_ = - FmhaBwdPipelineEnumSelector::value; + FmhaBwdPipelineEnumSelector:: + value; using FmhaBwdPipeline_ = typename FmhaBwdPipelineMaker< FmhaBwdPipelineEnum_, @@ -149,7 +156,6 @@ struct batched_backward_causalmask_bias_dropout_dispatch { kPadHeadDim>>; using FmhaBwdDQDKDVKernel_ = ck_tile::FmhaBwdDQDKDVKernel< - FmhaBwdTilePartitioner_, FmhaBwdPipeline_, FmhaBwdKGradEpilogue_, FmhaBwdVGradEpilogue_>; @@ -158,6 +164,46 @@ struct batched_backward_causalmask_bias_dropout_dispatch { }); }); }; + if constexpr (NeedConvertGradQ) { + constexpr ck_tile::index_t kBlockSize = 256; + + const bool pad_seqlen_q = !(param.M % kBlockSize == 0); + const bool pad_headdim_q = + !(param.K % FmhaBwdShape::kQKHeaddim == 0); + + BOOL_SWITCH_2( + pad_seqlen_q, kPadSeqLenQ, pad_headdim_q, kPadHeadDimQ, [&] { + constexpr ck_tile::index_t occupancy = 2; + + using FmhaBwdConvertQGradTraits_ = + ck_tile::TileFmhaBwdConvertQGradTraits< + kPadSeqLenQ, + kPadHeadDimQ, + occupancy>; + + using FmhaBwdConvertQGradPipelineProblem = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + kBlockSize, + FmhaBwdShape::kM0, + FmhaBwdShape::kN0, + FmhaBwdShape::kQKHeaddim, + false, // kIsGroupMode + false, // kIsDeterministic + FmhaBwdConvertQGradTraits_>; + + using FmhaBwdConvertQGradPipeline = + typename ck_tile::BlockFmhaBwdConvertQGrad< + FmhaBwdConvertQGradPipelineProblem>; + + using FmhaBwdConvertQGradKernel_ = + ck_tile::FmhaBwdConvertQGradKernel; + + RunWithBwdConvertQGradKernel( + param, stream); + }); + }; } template @@ -208,10 +254,10 @@ struct batched_backward_causalmask_bias_dropout_dispatch { param.grad_out_ptr, param.dot_out_ptr, nullptr, // rand_val_ptr - param.grad_q_ptr, param.grad_k_ptr, param.grad_v_ptr, param.grad_bias_ptr, + NeedConvertGradQ ? param.grad_q_f32_ptr : param.grad_q_ptr, param.M, // seqlen_q param.N, // seqlen_k param.K, @@ -219,25 +265,29 @@ struct batched_backward_causalmask_bias_dropout_dispatch { param.Hq, param.Hq / param.Hkv, param.scale, - param.q_strides[1], // q, k, v, bias, do, dk, dv, dbias seq-dim - // stride + param.q_strides[1], // q, k, v, bias, do, dq_f32, dk, dv, dbias + // seq-dim stride param.k_strides[1], param.v_strides[1], param.attn_bias_strides[2], 0, // stride_randval param.grad_out_strides[1], + NeedConvertGradQ ? 
param.grad_q_f32_strides[1] : param.q_strides[1], param.grad_k_strides[1], param.grad_v_strides[1], param.attn_bias_strides[2], // assume grad_bias has same strides as // bias - param.q_strides[2], // q, k, v, bias, do, lse/dot, dbias - // nhead-dim strides + param.q_strides[2], // q, k, v, bias, do, lse/dot, dq_f32, dk, dv, + // dbias nhead-dim strides param.k_strides[2], param.v_strides[2], param.attn_bias_strides[1], 0, // nhead_stride_randval param.grad_out_strides[2], param.lsed_strides[1], + NeedConvertGradQ ? param.grad_q_f32_strides[2] : param.q_strides[2], + param.grad_k_strides[2], + param.grad_v_strides[2], param.attn_bias_strides[1], // assume grad_bias has same strides as // bias param.q_strides[0], // q, k, v, bias, do, lse/dot, dk, dv, dbias, @@ -248,16 +298,17 @@ struct batched_backward_causalmask_bias_dropout_dispatch { 0, // batch_stride_randval param.grad_out_strides[0], param.lsed_strides[0], // lse/dot is in BHM contiguous layout + NeedConvertGradQ ? param.grad_q_f32_strides[0] : param.q_strides[0], param.grad_k_strides[0], param.grad_v_strides[0], param.attn_bias_strides[0], // assume grad_bias has same strides as // bias + 0, // split_stride_dq_acc (param.window_size > 0) ? param.window_size - 1 : -1, // window_left_size (param.custom_mask_type == 0) ? -1 : 0, // window_right_size param.custom_mask_type, param.dropout_prob, // dropout ratio - false, // is_store_randval {param.philox_seed, param.philox_offset}); }(); @@ -270,6 +321,38 @@ struct batched_backward_causalmask_bias_dropout_dispatch { ck_tile::make_kernel( FmhaBwdDQDKDVKernel{}, kGridSize, kBlockSize, 0, kargs)); } + + template + static void RunWithBwdConvertQGradKernel( + BatchedBackwardParams& param, + hipStream_t stream) { + const auto kargs = [&] { + return FmhaBwdConvertQGradKernel::MakeKargs( + param.grad_q_f32_ptr, + param.grad_q_ptr, + param.M, // seqlen_q + param.N, // seqlen_k + param.K, // headdim of q/k + param.q_strides[1], + param.grad_q_f32_strides[1], + param.q_strides[2], + param.grad_q_f32_strides[2], + param.q_strides[0], + param.grad_q_f32_strides[0], + 0); + }(); + + dim3 kGridSize = + FmhaBwdConvertQGradKernel::GridSize(param.B, param.Hq, param.M); + constexpr dim3 kBlockSize = FmhaBwdConvertQGradKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = + FmhaBwdConvertQGradKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel( + FmhaBwdConvertQGradKernel{}, kGridSize, kBlockSize, 0, kargs)); + } }; template < diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward_bf16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward_bf16.cpp index a9e17ee73..3cf339b83 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward_bf16.cpp +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward_bf16.cpp @@ -11,85 +11,7 @@ #include "ck_tiled_fmha_batched_backward.h" #include "ck_tiled_headdim_switch.h" -// clang-format off -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void 
run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); - -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); - -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); - -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); - -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); - -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void 
run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -// clang-format on +#include "instances/fmha_batched_backward_bf16_instances_ref.h" void batched_backward_bf16(BatchedBackwardParams& param, hipStream_t stream) { const bool has_dropout = (param.dropout_prob > 0.0f); diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward_fp16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward_fp16.cpp index 17c4aa9d3..807169ccd 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward_fp16.cpp +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward_fp16.cpp @@ -11,85 +11,7 @@ #include "ck_tiled_fmha_batched_backward.h" #include "ck_tiled_headdim_switch.h" -// clang-format off -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); - -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); - -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template 
void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); - -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); - -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); - -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch( - BatchedBackwardParams& param, hipStream_t stream); -// clang-format on +#include "instances/fmha_batched_backward_fp16_instances_ref.h" void batched_backward_fp16(BatchedBackwardParams& param, hipStream_t stream) { const bool has_dropout = (param.dropout_prob > 0.0f); diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_bf16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_bf16.cpp index e27552d3e..bd2e076e0 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_bf16.cpp +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_bf16.cpp @@ -11,79 +11,7 @@ #include "ck_tiled_fmha_batched_forward.h" #include "ck_tiled_headdim_switch.h" -// clang-format off -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void 
run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, 
hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -// clang-format on +#include "instances/fmha_batched_forward_bf16_instances_ref.h" void batched_forward_bf16(BatchedForwardParams& param, hipStream_t stream) { const bool has_dropout = (param.dropout_prob > 0.0f); diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_fp16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_fp16.cpp index a65f6a2a2..3c3791bdf 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_fp16.cpp +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_fp16.cpp @@ -11,79 +11,7 @@ #include "ck_tiled_fmha_batched_forward.h" #include "ck_tiled_headdim_switch.h" -// clang-format off -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void 
run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -// clang-format on +#include "instances/fmha_batched_forward_fp16_instances_ref.h" void batched_forward_fp16(BatchedForwardParams& param, hipStream_t stream) { const bool has_dropout = (param.dropout_prob > 0.0f); diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer.h index 05d654dc3..36cf1b56e 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer.h +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer.h @@ -69,7 +69,8 @@ struct batched_infer_causalmask_bias_dropout_dispatch { const bool pad_headdim = (pad_headdim_q || pad_headdim_v); const bool use_async_pipeline = - ((param.K % 8 == 0) && (param.Kv % 8 == 0) && (MaxK <= 128)); + (!kHasBias && (param.K % 8 == 0) && (param.Kv % 8 == 0) && + (MaxK <= 128)); if (!use_async_pipeline) { BOOL_SWITCH_3( diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_bf16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_bf16.cpp index b362a780f..23b04d935 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_bf16.cpp +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_bf16.cpp @@ -10,79 +10,7 @@ #include "ck_tiled_bool_switch.h" #include "ck_tiled_fmha_batched_infer.h" -// clang-format off -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void 
run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -// clang-format on +#include "instances/fmha_batched_infer_bf16_instances_ref.h" void batched_infer_bf16(BatchedForwardParams& param, hipStream_t stream) { const bool has_dropout = (param.dropout_prob > 0.0f); diff --git 
a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_fp16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_fp16.cpp index e55003c60..4e1d99e8e 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_fp16.cpp +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_fp16.cpp @@ -10,79 +10,7 @@ #include "ck_tiled_bool_switch.h" #include "ck_tiled_fmha_batched_infer.h" -// clang-format off -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( 
- BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); - -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch( - BatchedForwardParams& param, hipStream_t stream); -// clang-format on +#include "instances/fmha_batched_infer_fp16_instances_ref.h" void batched_infer_fp16(BatchedForwardParams& param, hipStream_t stream) { const bool has_dropout = (param.dropout_prob > 0.0f); diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_bwd_setting.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_bwd_setting.h index 4ef24248a..9e2ba4818 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_bwd_setting.h +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_bwd_setting.h @@ -8,6 +8,7 @@ #include #include +#include template struct FmhaBwdTypeConfig; @@ -55,94 +56,105 @@ struct FmhaBwdBlockTile; template <> struct FmhaBwdBlockTile<32> { - using type = ck_tile::sequence<128, 128, 32, 32, 32, 32, 32, 32, 32>; + using tile_lengths = ck_tile::sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; using gemm02_warps = ck_tile::sequence<1, 4, 1>; // default for gemm0/gemm2 using gemm13_warps = ck_tile::sequence<4, 1, 1>; // default for gemm1/gemm3 - using gemm4_warps = ck_tile::sequence<4, 1, 1>; // default for gemm4 + using gemm4_warps = ck_tile::sequence<2, 2, 1>; // default for gemm4 }; template <> struct FmhaBwdBlockTile<64> { - using type = ck_tile::sequence<64, 128, 32, 32, 32, 32, 32, 64, 64>; + using tile_lengths = ck_tile::sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; using gemm02_warps = ck_tile::sequence<1, 4, 1>; // default for gemm0/gemm2 using gemm13_warps = ck_tile::sequence<4, 1, 1>; // default for gemm1/gemm3 - using gemm4_warps = ck_tile::sequence<2, 2, 1>; // default for gemm4 + using gemm4_warps = ck_tile::sequence<1, 4, 1>; // default for gemm4 }; template <> struct FmhaBwdBlockTile<128> { - using type = ck_tile::sequence<64, 128, 32, 32, 32, 32, 32, 128, 128>; + using tile_lengths = + ck_tile::sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; using gemm02_warps = ck_tile::sequence<1, 4, 1>; // default for gemm0/gemm2 using gemm13_warps = ck_tile::sequence<4, 1, 1>; // default for gemm1/gemm3 - using gemm4_warps = ck_tile::sequence<2, 2, 1>; // default for gemm4 + using gemm4_warps = ck_tile::sequence<1, 4, 1>; // default for gemm4 +}; + +template <> +struct FmhaBwdBlockTile<256> { + using tile_lengths = + ck_tile::sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; + using gemm02_warps = ck_tile::sequence<1, 4, 1>; // default for gemm0/gemm2 + using gemm13_warps = ck_tile::sequence<4, 1, 1>; // default for gemm1/gemm3 + using gemm4_warps = ck_tile::sequence<1, 4, 1>; // default for gemm4 }; -using FmhaBwdWarpTile = ck_tile::sequence<32, 32, 16>; +using FmhaBwdWarpTile1 = ck_tile::sequence<32, 32, 16>; +using FmhaBwdWarpTile2 = 
ck_tile::sequence<16, 16, 32>; +using FmhaBwdWarpTile3 = ck_tile::sequence<16, 16, 16>; template struct FmhaBwdShape; template <> struct FmhaBwdShape<32> : ck_tile::TileFmhaBwdShape< - typename FmhaBwdBlockTile<32>::type, + typename FmhaBwdBlockTile<32>::tile_lengths, typename FmhaBwdBlockTile<32>::gemm02_warps, - FmhaBwdWarpTile, + FmhaBwdWarpTile2, typename FmhaBwdBlockTile<32>::gemm13_warps, - FmhaBwdWarpTile, + FmhaBwdWarpTile3, typename FmhaBwdBlockTile<32>::gemm02_warps, - FmhaBwdWarpTile, + FmhaBwdWarpTile2, typename FmhaBwdBlockTile<32>::gemm13_warps, - FmhaBwdWarpTile, + FmhaBwdWarpTile3, typename FmhaBwdBlockTile<32>::gemm4_warps, - FmhaBwdWarpTile> {}; + FmhaBwdWarpTile2> {}; template <> struct FmhaBwdShape<64> : ck_tile::TileFmhaBwdShape< - typename FmhaBwdBlockTile<64>::type, + typename FmhaBwdBlockTile<64>::tile_lengths, typename FmhaBwdBlockTile<64>::gemm02_warps, - FmhaBwdWarpTile, + FmhaBwdWarpTile2, typename FmhaBwdBlockTile<64>::gemm13_warps, - FmhaBwdWarpTile, + FmhaBwdWarpTile3, typename FmhaBwdBlockTile<64>::gemm02_warps, - FmhaBwdWarpTile, + FmhaBwdWarpTile2, typename FmhaBwdBlockTile<64>::gemm13_warps, - FmhaBwdWarpTile, + FmhaBwdWarpTile3, typename FmhaBwdBlockTile<64>::gemm4_warps, - FmhaBwdWarpTile> {}; + FmhaBwdWarpTile2> {}; template <> struct FmhaBwdShape<128> : ck_tile::TileFmhaBwdShape< - typename FmhaBwdBlockTile<128>::type, + typename FmhaBwdBlockTile<128>::tile_lengths, typename FmhaBwdBlockTile<128>::gemm02_warps, - FmhaBwdWarpTile, + FmhaBwdWarpTile2, typename FmhaBwdBlockTile<128>::gemm13_warps, - FmhaBwdWarpTile, + FmhaBwdWarpTile3, typename FmhaBwdBlockTile<128>::gemm02_warps, - FmhaBwdWarpTile, + FmhaBwdWarpTile2, typename FmhaBwdBlockTile<128>::gemm13_warps, - FmhaBwdWarpTile, + FmhaBwdWarpTile3, typename FmhaBwdBlockTile<128>::gemm4_warps, - FmhaBwdWarpTile> {}; - -template -struct FmhaBwdPipelineEnumSelector; - -template <> -struct FmhaBwdPipelineEnumSelector<32> { - static constexpr ck_tile::BlockFmhaBwdPipelineEnum value = - ck_tile::BlockFmhaBwdPipelineEnum::QSKSVROGradS; -}; - -template <> -struct FmhaBwdPipelineEnumSelector<64> { - static constexpr ck_tile::BlockFmhaBwdPipelineEnum value = - ck_tile::BlockFmhaBwdPipelineEnum::KSKTSVR; -}; + FmhaBwdWarpTile2> {}; template <> -struct FmhaBwdPipelineEnumSelector<128> { +struct FmhaBwdShape<256> : ck_tile::TileFmhaBwdShape< + typename FmhaBwdBlockTile<256>::tile_lengths, + typename FmhaBwdBlockTile<256>::gemm02_warps, + FmhaBwdWarpTile2, + typename FmhaBwdBlockTile<256>::gemm13_warps, + FmhaBwdWarpTile3, + typename FmhaBwdBlockTile<256>::gemm02_warps, + FmhaBwdWarpTile2, + typename FmhaBwdBlockTile<256>::gemm13_warps, + FmhaBwdWarpTile3, + typename FmhaBwdBlockTile<256>::gemm4_warps, + FmhaBwdWarpTile2> {}; + +template +struct FmhaBwdPipelineEnumSelector { static constexpr ck_tile::BlockFmhaBwdPipelineEnum value = - ck_tile::BlockFmhaBwdPipelineEnum::KSVR; + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP; }; template @@ -150,19 +162,30 @@ struct FmhaBwdPipelineMaker; template struct FmhaBwdPipelineMaker< - ck_tile::BlockFmhaBwdPipelineEnum::QSKSVROGradS, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, problem> { - using pipeline = ck_tile::BlockFmhaBwdDQDKDVPipelineQSKSVROGradS; + using pipeline = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; }; template struct FmhaBwdPipelineMaker< - ck_tile::BlockFmhaBwdPipelineEnum::KSKTSVR, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, problem> { - using pipeline = ck_tile::BlockFmhaBwdDQDKDVPipelineKSKTSVR; + using pipeline = 
ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; }; -template -struct FmhaBwdPipelineMaker { - using pipeline = ck_tile::BlockFmhaBwdDQDKDVPipelineKSVR; +template +struct FmhaBwdBlockDropoutMaker; + +template +struct FmhaBwdBlockDropoutMaker { + using dropout = ck_tile::BlockDropoutBwd; +}; + +template +struct FmhaBwdBlockDropoutMaker { + using FmhaBwdShapeType = FmhaBwdShape; + static constexpr bool IsWG32 = + (FmhaBwdShapeType::Gemm0WarpTile::at(ck_tile::number<0>{}) == 32); + using dropout = ck_tile::BlockDropoutBwd; }; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_setting.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_setting.h index 662703b7e..ddd91a686 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_setting.h +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_setting.h @@ -8,6 +8,7 @@ #include #include +#include template struct FmhaFwdTypeConfig; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward.h index b5038fdfe..82d9920f6 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward.h +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward.h @@ -23,6 +23,9 @@ template < bool kHasDropout, ck_tile::index_t MaxK> struct grouped_backward_causalmask_bias_dropout_dispatch { + using FmhaBlockDropout = + typename FmhaBwdBlockDropoutMaker::dropout; + template using FmhaBwdPipelineProblemTemp = ck_tile::BlockFmhaBwdPipelineProblem< typename FmhaBwdTypeConfig::QDataType, @@ -42,44 +45,47 @@ struct grouped_backward_causalmask_bias_dropout_dispatch { typename FmhaBwdTypeConfig::BiasGradDataType, FmhaBwdShape, true, // kIsGroupMode + false, // non-deterministic FmhaMask, + FmhaBlockDropout, FmhaTraits>; + static constexpr bool NeedConvertGradQ = !std::is_same< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType>::value; + static void Run(GroupedBackwardParams& param, hipStream_t stream) { { - constexpr ck_tile::index_t kBlockSize = 256; - bool pad_seqlen_q = !(param.M % kBlockSize == 0); + constexpr ck_tile::index_t kBlockSize = 64; bool pad_headdim_v = !(param.Kv % FmhaBwdShape::kVHeaddim == 0); - BOOL_SWITCH_2( - pad_seqlen_q, kPadSeqLenQ, pad_headdim_v, kPadHeadDimV, [&] { - constexpr ck_tile::index_t occupancy = 2; + constexpr bool kPadSeqLenQ = true; - using FmhaOGradDotOTraits_ = ck_tile::TileFmhaBwdOGradDotOTraits< - kPadSeqLenQ, - kPadHeadDimV, - occupancy>; + BOOL_SWITCH(pad_headdim_v, kPadHeadDimV, [&] { + constexpr ck_tile::index_t occupancy = 2; - using FmhaBwdOGradDotOPipelineProblem = - ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< - typename FmhaBwdTypeConfig::ODataType, - typename FmhaBwdTypeConfig::OGradDataType, - typename FmhaBwdTypeConfig::DDataType, - kBlockSize, - FmhaBwdShape::kVHeaddim, - true, // kIsGroupMode - FmhaOGradDotOTraits_>; + using FmhaOGradDotOTraits_ = ck_tile:: + TileFmhaBwdOGradDotOTraits; - using FmhaBwdOGradDotOPipeline_ = - typename ck_tile::BlockFmhaBwdOGradDotO< - FmhaBwdOGradDotOPipelineProblem>; + using FmhaBwdOGradDotOPipelineProblem = + ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + kBlockSize, + FmhaBwdShape::kVHeaddim, + true, // kIsGroupMode + FmhaOGradDotOTraits_>; - using FmhaBwdOGradDotOKernel_ = ck_tile::FmhaBwdOGradDotOKernel< - ck_tile::FmhaBwdOGradDotOTilePartitioner, - FmhaBwdOGradDotOPipeline_>; + using 
FmhaBwdOGradDotOPipeline_ = + typename ck_tile::BlockFmhaBwdOGradDotO< + FmhaBwdOGradDotOPipelineProblem>; - RunWithBwdOGradDotOKernel(param, stream); - }); + using FmhaBwdOGradDotOKernel_ = + ck_tile::FmhaBwdOGradDotOKernel; + + RunWithBwdOGradDotOKernel(param, stream); + }); }; { @@ -92,10 +98,6 @@ struct grouped_backward_causalmask_bias_dropout_dispatch { using FmhaMask = ck_tile::SimplifiedGenericAttentionMask; - using FmhaBwdShape_ = FmhaBwdShape; - using FmhaBwdTilePartitioner_ = - ck_tile::FmhaBwdTilePartitioner; - constexpr auto kBiasEnum = kHasBias ? ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS : ck_tile::BlockAttentionBiasEnum::NO_BIAS; @@ -103,8 +105,10 @@ struct grouped_backward_causalmask_bias_dropout_dispatch { constexpr bool kPadSeqLenQ = true; constexpr bool kPadSeqLenK = true; - const bool pad_headdim_q = !(param.K % FmhaBwdShape_::kQKHeaddim == 0); - const bool pad_headdim_v = !(param.Kv % FmhaBwdShape_::kVHeaddim == 0); + const bool pad_headdim_q = + !(param.K % FmhaBwdShape::kQKHeaddim == 0); + const bool pad_headdim_v = + !(param.Kv % FmhaBwdShape::kVHeaddim == 0); // usually headdim_q and headdim_v are same, consider them together // to determine whether to do padding saving some compiling time @@ -119,7 +123,7 @@ struct grouped_backward_causalmask_bias_dropout_dispatch { kBiasEnum, kHasBiasGrad, false, // kStoreLSE - kHasDropout, + false, // place-holder for kHasDropout, not used actually false, // kDoFp8StaticQuant place-holder occupancy>; @@ -127,7 +131,8 @@ struct grouped_backward_causalmask_bias_dropout_dispatch { FmhaBwdPipelineProblemTemp; constexpr auto FmhaBwdPipelineEnum_ = - FmhaBwdPipelineEnumSelector::value; + FmhaBwdPipelineEnumSelector:: + value; using FmhaBwdPipeline_ = typename FmhaBwdPipelineMaker< FmhaBwdPipelineEnum_, @@ -148,7 +153,6 @@ struct grouped_backward_causalmask_bias_dropout_dispatch { kPadHeadDim>>; using FmhaBwdDQDKDVKernel_ = ck_tile::FmhaBwdDQDKDVKernel< - FmhaBwdTilePartitioner_, FmhaBwdPipeline_, FmhaBwdKGradEpilogue_, FmhaBwdVGradEpilogue_>; @@ -157,6 +161,47 @@ struct grouped_backward_causalmask_bias_dropout_dispatch { }); }); }; + + if constexpr (NeedConvertGradQ) { + constexpr ck_tile::index_t kBlockSize = 128; + + const bool pad_seqlen_q = true; + const bool pad_headdim_q = + !(param.K % FmhaBwdShape::kQKHeaddim == 0); + + BOOL_SWITCH_2( + pad_seqlen_q, kPadSeqLenQ, pad_headdim_q, kPadHeadDimQ, [&] { + constexpr ck_tile::index_t occupancy = 2; + + using FmhaBwdConvertQGradTraits_ = + ck_tile::TileFmhaBwdConvertQGradTraits< + kPadSeqLenQ, + kPadHeadDimQ, + occupancy>; + + using FmhaBwdConvertQGradPipelineProblem = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + kBlockSize, + 64, // kM0 + 1, // kN0, no use + FmhaBwdShape::kQKHeaddim, + true, // kIsGroupMode + false, // kIsDeterministic + FmhaBwdConvertQGradTraits_>; + + using FmhaBwdConvertQGradPipeline = + typename ck_tile::BlockFmhaBwdConvertQGrad< + FmhaBwdConvertQGradPipelineProblem>; + + using FmhaBwdConvertQGradKernel_ = + ck_tile::FmhaBwdConvertQGradKernel; + + RunWithBwdConvertQGradKernel( + param, stream); + }); + }; } template @@ -175,8 +220,7 @@ struct grouped_backward_causalmask_bias_dropout_dispatch { param.out_strides[0], // stride_o param.grad_out_strides[1], // nhead_stride_do param.out_strides[1], // nhead_stride_o - param.lsed_strides[1], - param.lsed_strides[0]); // batch_stride_d + param.lsed_strides[0]); // nhead_stride_d }(); dim3 kGridSize = 
FmhaBwdOGradDotOKernel::GridSize( @@ -205,10 +249,10 @@ struct grouped_backward_causalmask_bias_dropout_dispatch { param.grad_out_ptr, param.dot_out_ptr, nullptr, // randval_ptr - param.grad_q_ptr, param.grad_k_ptr, param.grad_v_ptr, param.grad_bias_ptr, + NeedConvertGradQ ? param.grad_q_f32_ptr : param.grad_q_ptr, param.seqstart_q_dev_ptr, param.seqstart_k_dev_ptr, param.seqlen_k_dev_ptr, @@ -217,34 +261,37 @@ struct grouped_backward_causalmask_bias_dropout_dispatch { param.Hq, param.Hq / param.Hkv, param.scale, - param.q_strides[0], // q, k, v, bias, do, dk, dv, dbias seq-dim - // stride + param.q_strides[0], // q, k, v, bias, do, dq_f32, dk, dv, dbias + // seq-dim stride param.k_strides[0], param.v_strides[0], param.attn_bias_strides[1], 0, // stride_randval param.grad_out_strides[0], + NeedConvertGradQ ? param.grad_q_f32_strides[0] : param.q_strides[0], param.grad_k_strides[0], param.grad_v_strides[0], param.attn_bias_strides[1], // assume grad_bias has same strides as - // bias - param.q_strides[1], // q, k, v, bias, do, lse/dot, dbias - // nhead-dim strides + // bias. + param.q_strides[1], // q, k, v, bias, do, lse/dot, dq_f32, dk, dv, + // dbias nhead-dim strides param.k_strides[1], param.v_strides[1], param.attn_bias_strides[0], 0, // nhead_stride_randval param.grad_out_strides[1], - param.lsed_strides[1], // assume lse/dot is in BHM contiguous layout + param.lsed_strides[0], // assume lse/dot is in HM contiguous layout + NeedConvertGradQ ? param.grad_q_f32_strides[1] : param.q_strides[1], + param.grad_k_strides[1], + param.grad_v_strides[1], param.attn_bias_strides[0], // assume grad_bias has same strides as // bias - param.lsed_strides[0], // batch_stride_lse + 0, // split_stride_dq_acc (param.window_size > 0) ? param.window_size - 1 : -1, // window_left_size (param.custom_mask_type == 0) ? 
-1 : 0, // window_right_size param.custom_mask_type, param.dropout_prob, // dropout ratio - false, // is_store_randval {param.philox_seed, param.philox_offset}); }(); @@ -258,6 +305,36 @@ struct grouped_backward_causalmask_bias_dropout_dispatch { ck_tile::make_kernel( FmhaBwdDQDKDVKernel{}, kGridSize, kBlockSize, 0, kargs)); } + + template + static void RunWithBwdConvertQGradKernel( + GroupedBackwardParams& param, + hipStream_t stream) { + const auto kargs = [&] { + return FmhaBwdConvertQGradKernel::MakeKargs( + param.grad_q_f32_ptr, + param.grad_q_ptr, + param.seqstart_q_dev_ptr, + param.seqstart_k_dev_ptr, + param.K, // headdim of q/k + param.q_strides[0], + param.grad_q_f32_strides[0], + param.q_strides[1], + param.grad_q_f32_strides[1], + 0); + }(); + + dim3 kGridSize = FmhaBwdConvertQGradKernel::GridSize( + param.num_batches, param.Hq, param.max_seqlen_q); + constexpr dim3 kBlockSize = FmhaBwdConvertQGradKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = + FmhaBwdConvertQGradKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel( + FmhaBwdConvertQGradKernel{}, kGridSize, kBlockSize, 0, kargs)); + } }; template < diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward_bf16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward_bf16.cpp index 5d08a4d72..7b77442be 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward_bf16.cpp +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward_bf16.cpp @@ -11,85 +11,7 @@ #include "ck_tiled_fmha_grouped_backward.h" #include "ck_tiled_headdim_switch.h" -// clang-format off -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); - -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); - -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void 
run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); - -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); - -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); - -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -// clang-format on +#include "instances/fmha_grouped_backward_bf16_instances_ref.h" void grouped_backward_bf16(GroupedBackwardParams& param, hipStream_t stream) { const bool has_dropout = (param.dropout_prob > 0.0f); diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward_fp16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward_fp16.cpp index 266cd0ad1..be47bbdbb 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward_fp16.cpp +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward_fp16.cpp @@ -11,85 +11,7 @@ #include "ck_tiled_fmha_grouped_backward.h" #include "ck_tiled_headdim_switch.h" -// clang-format off -extern template 
void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); - -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); - -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); - -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); - -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void 
run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); - -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch( - GroupedBackwardParams& param, hipStream_t stream); -// clang-format on +#include "instances/fmha_grouped_backward_fp16_instances_ref.h" void grouped_backward_fp16(GroupedBackwardParams& param, hipStream_t stream) { const bool has_dropout = (param.dropout_prob > 0.0f); diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward.h index 2fa305e0a..519a5ea89 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward.h +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward.h @@ -150,9 +150,8 @@ struct grouped_forward_causalmask_bias_dropout_dispatch { param.v_strides[1], param.attn_bias_strides[1], 0, // nhead_stride_randval - param.lse_strides[1], + param.lse_strides[0], param.out_strides[1], - param.lse_strides[0], // batch_stride_lse (param.window_size > 0) ? param.window_size - 1 : -1, // window_left_size (param.custom_mask_type == 0) ? 
-1 : 0, // window_right_size diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_bf16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_bf16.cpp index e04af2e8a..28d75ddc5 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_bf16.cpp +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_bf16.cpp @@ -11,79 +11,7 @@ #include "ck_tiled_fmha_grouped_forward.h" #include "ck_tiled_headdim_switch.h" -// clang-format off -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); - -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); - -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); - -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); - -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); - -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& 
param, hipStream_t stream); - -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); - -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -// clang-format on +#include "instances/fmha_grouped_forward_bf16_instances_ref.h" void grouped_forward_bf16(GroupedForwardParams& param, hipStream_t stream) { const bool has_dropout = (param.dropout_prob > 0.0f); diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_fp16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_fp16.cpp index 13276415e..31e28bad6 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_fp16.cpp +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_fp16.cpp @@ -11,79 +11,7 @@ #include "ck_tiled_fmha_grouped_forward.h" #include "ck_tiled_headdim_switch.h" -// clang-format off -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); - -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); - -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); - -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template 
void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); - -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); - -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); - -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); - -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -// clang-format on +#include "instances/fmha_grouped_forward_fp16_instances_ref.h" void grouped_forward_fp16(GroupedForwardParams& param, hipStream_t stream) { const bool has_dropout = (param.dropout_prob > 0.0f); diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer.h index 5197a6cb1..3805108c1 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer.h +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer.h @@ -63,7 +63,8 @@ struct grouped_infer_causalmask_bias_dropout_dispatch { bool pad_headdim_q = !(param.K % FmhaShape::kK0BlockLength == 0); bool pad_headdim_v = !(param.Kv % FmhaShape::kN1 == 0); const bool use_async_pipeline = - ((param.K % 8 == 0) && (param.Kv % 8 == 0) && (MaxK <= 128)); + (!kHasBias && (param.K % 8 == 0) && (param.Kv % 8 == 0) && + (MaxK <= 128)); if (!use_async_pipeline) { BOOL_SWITCH_2( @@ -196,7 +197,6 @@ struct grouped_infer_causalmask_bias_dropout_dispatch { 0, // nhead_stride_randval 0, // nhead_stride_lse param.out_strides[1], - 0, // batch_stride_lse (param.window_size > 0) ? param.window_size - 1 : -1, // window_left_size (param.custom_mask_type == 0) ? 
-1 : 0, // window_right_size diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_bf16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_bf16.cpp index 5b0fb5b37..090227c1d 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_bf16.cpp +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_bf16.cpp @@ -10,79 +10,7 @@ #include "ck_tiled_bool_switch.h" #include "ck_tiled_fmha_grouped_infer.h" -// clang-format off -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); - -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); - -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); - -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); - -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); - -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); - -extern template void 
run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); - -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -// clang-format on +#include "instances/fmha_grouped_infer_bf16_instances_ref.h" void grouped_infer_bf16(GroupedForwardParams& param, hipStream_t stream) { const bool has_dropout = (param.dropout_prob > 0.0f); diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_fp16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_fp16.cpp index fa0a407f1..62c774ff5 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_fp16.cpp +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_fp16.cpp @@ -10,79 +10,7 @@ #include "ck_tiled_bool_switch.h" #include "ck_tiled_fmha_grouped_infer.h" -// clang-format off -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); - -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); - -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); - -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); 
-extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); - -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); - -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); - -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); - -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch( - GroupedForwardParams& param, hipStream_t stream); -// clang-format on +#include "instances/fmha_grouped_infer_fp16_instances_ref.h" void grouped_infer_fp16(GroupedForwardParams& param, hipStream_t stream) { const bool has_dropout = (param.dropout_prob > 0.0f); diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_params.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_params.h index e97db1e86..ce86f6df4 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_params.h +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_params.h @@ -28,9 +28,6 @@ struct BatchedInferParams { std::array out_strides; std::array attn_bias_strides; // 4d tensor_view [B, H, M, N] - // BHM mode strides, completely contiguous - std::array lse_strides; - const void* q_ptr; const void* k_ptr; const void* v_ptr; @@ -49,6 +46,9 @@ struct BatchedForwardParams : public BatchedInferParams { int64_t philox_seed; int64_t philox_offset; + // BHM mode strides, completely contiguous + std::array lse_strides; + // completely contiguous void* logsumexp_ptr; }; @@ -80,9 +80,6 @@ struct GroupedInferParams { // 4d tensor view [B, H, M, N] std::array attn_bias_strides; - // BHM mode strides, completely contiguous - std::array lse_strides; - const void* q_ptr; const void* k_ptr; const void* v_ptr; @@ -102,6 +99,10 @@ struct GroupedForwardParams : public GroupedInferParams { int64_t philox_seed; int64_t philox_offset; + // HM mode strides, completely contiguous, 
unpadded layout where M is + // the concatenated total seqlen_q for all batches + std::array lse_strides; + // completely contiguous void* logsumexp_ptr; }; @@ -132,6 +133,9 @@ struct BatchedBackwardParams { std::array grad_k_strides; std::array grad_v_strides; + // assume grad_q has same strides as q, but grad_q_f32 can be different + std::array grad_q_f32_strides; + // BHM mode strides, completely contiguous std::array lsed_strides; @@ -150,6 +154,8 @@ void* grad_v_ptr; void* grad_bias_ptr; + void* grad_q_f32_ptr; + float dropout_prob; int64_t philox_seed; int64_t philox_offset; @@ -193,8 +199,12 @@ struct GroupedBackwardParams { std::array grad_k_strides; std::array grad_v_strides; - // BHM mode strides, completely contiguous - std::array lsed_strides; + // assume grad_q has same strides as q, but grad_q_f32 can be different + std::array grad_q_f32_strides; + + // HM mode strides, completely contiguous, unpadded layout where M is + // the concatenated total seqlen_q for all batches + std::array lsed_strides; const void* q_ptr; const void* k_ptr; @@ -211,6 +221,8 @@ void* grad_v_ptr; void* grad_bias_ptr; + void* grad_q_f32_ptr; + float dropout_prob; int64_t philox_seed; int64_t philox_offset; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_headdim_switch.h b/xformers/csrc/attention/hip_fmha/ck_tiled_headdim_switch.h index 18814324b..ce99023c9 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_headdim_switch.h +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_headdim_switch.h @@ -9,6 +9,46 @@ #include #include +#ifndef FMHA_SUPPORT_MAX_HEADDIM_128 +#define FMHA_SUPPORT_MAX_HEADDIM_128 0 +#endif + +#if FMHA_SUPPORT_MAX_HEADDIM_128 + +#define FMHA_FWD_HEADDIM_SWITCH(HEAD_DIM1, HEAD_DIM2, CONST_NAME, ...) \ + [&] { \ + if (HEAD_DIM1 <= 32 && HEAD_DIM2 <= 32) { \ + constexpr ck_tile::index_t CONST_NAME = 32; \ + __VA_ARGS__(); \ + } else if (HEAD_DIM1 <= 64 && HEAD_DIM2 <= 64) { \ + constexpr ck_tile::index_t CONST_NAME = 64; \ + __VA_ARGS__(); \ + } else if (HEAD_DIM1 <= 128 && HEAD_DIM2 <= 128) { \ + constexpr ck_tile::index_t CONST_NAME = 128; \ + __VA_ARGS__(); \ + } else { \ + throw std::runtime_error("Head-dim sizes not supported!"); \ + } \ + }() + +#define FMHA_BWD_HEADDIM_SWITCH(HEAD_DIM1, HEAD_DIM2, CONST_NAME, ...) \ + [&] { \ + if (HEAD_DIM1 <= 32 && HEAD_DIM2 <= 32) { \ + constexpr ck_tile::index_t CONST_NAME = 32; \ + __VA_ARGS__(); \ + } else if (HEAD_DIM1 <= 64 && HEAD_DIM2 <= 64) { \ + constexpr ck_tile::index_t CONST_NAME = 64; \ + __VA_ARGS__(); \ + } else if (HEAD_DIM1 <= 128 && HEAD_DIM2 <= 128) { \ + constexpr ck_tile::index_t CONST_NAME = 128; \ + __VA_ARGS__(); \ + } else { \ + throw std::runtime_error("Head-dim sizes not supported!"); \ + } \ + }() + +#else + #define FMHA_FWD_HEADDIM_SWITCH(HEAD_DIM1, HEAD_DIM2, CONST_NAME, ...)
\ [&] { \ if (HEAD_DIM1 <= 32 && HEAD_DIM2 <= 32) { \ @@ -39,7 +79,12 @@ } else if (HEAD_DIM1 <= 128 && HEAD_DIM2 <= 128) { \ constexpr ck_tile::index_t CONST_NAME = 128; \ __VA_ARGS__(); \ + } else if (HEAD_DIM1 <= 256 && HEAD_DIM2 <= 256) { \ + constexpr ck_tile::index_t CONST_NAME = 256; \ + __VA_ARGS__(); \ } else { \ throw std::runtime_error("Head-dim sizes not supported!"); \ } \ }() + +#endif diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_rand_uniform_kernel.h b/xformers/csrc/attention/hip_fmha/ck_tiled_rand_uniform_kernel.h index e930e0b82..715d5e4bd 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_rand_uniform_kernel.h +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_rand_uniform_kernel.h @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -34,6 +35,8 @@ struct FmhaRandUniformKernel { using BlockGemm = decltype(GetBlockGemm()); + using MyBlockDropout = ck_tile::BlockDropout; + static constexpr bool kPadSeqLenQ = true; static constexpr bool kPadSeqLenK = true; @@ -170,7 +173,7 @@ struct FmhaRandUniformKernel { } __device__ static constexpr ck_tile::index_t GetSmemSize() { - return ck_tile::BlockDropout::MakeRandValLdsBlockDescriptor() + return MyBlockDropout::MakeRandValLdsBlockDescriptor() .get_element_space_size(); } @@ -182,7 +185,7 @@ struct FmhaRandUniformKernel { RandValDramBlockWindowTmp& randval_dram_block_window_tmp) const { using namespace ck_tile; - auto randval_dram_window = BlockDropout::MakeRandvalDramWindow( + auto randval_dram_window = MyBlockDropout::MakeRandvalDramWindow( randval_dram_block_window_tmp, 0); const auto num_total_loop = @@ -201,17 +204,17 @@ struct FmhaRandUniformKernel { // randval tile in LDS auto randval_lds = make_tensor_view( reinterpret_cast(randval_smem_ptr), - BlockDropout::MakeRandValLdsBlockDescriptor()); + MyBlockDropout::MakeRandValLdsBlockDescriptor()); auto randval_lds_window = make_tile_window( randval_lds, - BlockDropout::MakeRandValLdsBlockDescriptor() + MyBlockDropout::MakeRandValLdsBlockDescriptor() .get_lengths(), {0, 0}); // register distribute auto randval_dist_generated = make_static_distributed_tensor( - BlockDropout::MakeRandValTileDistribution()); + MyBlockDropout::MakeRandValTileDistribution()); static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); @@ -219,7 +222,7 @@ struct FmhaRandUniformKernel { randval_lds_window.get_bottom_tensor_view(), randval_lds_window.get_window_lengths(), randval_lds_window.get_window_origin(), - BlockDropout::MakeRandValLdsShuffleTileDistribution()); + MyBlockDropout::MakeRandValLdsShuffleTileDistribution()); const int start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{}); diff --git a/xformers/csrc/attention/hip_fmha/generate_instances.py b/xformers/csrc/attention/hip_fmha/generate_instances.py index d2dc114d3..53dd8143c 100644 --- a/xformers/csrc/attention/hip_fmha/generate_instances.py +++ b/xformers/csrc/attention/hip_fmha/generate_instances.py @@ -6,11 +6,13 @@ # import os +import sys from pathlib import Path +from typing import List -FMHA_INSTANCE_HEADER = """ +FMHA_COPYRIGHT_HEADER = """ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
@@ -19,11 +21,13 @@ */ """ -FMHA_INFER_INSTANCE_TEMPLATE = """ +FMHA_INFER_INSTANCE_TEMPLATE_INC = """ #include #include \"ck_tiled_fmha_{mode}_infer.h\" +""" -template void run_{mode}_infer_causalmask_bias_dropout_dispatch< +FMHA_INFER_INSTANCE_TEMPLATE = """ +{extern}template void run_{mode}_infer_causalmask_bias_dropout_dispatch< {dtype}, {has_causalmask}, {has_bias}, @@ -36,11 +40,13 @@ "{has_or_no_bias_str}_{has_or_no_dropout_str}_{max_k_str}.cpp" ) -FMHA_FORWARD_INSTANCE_TEMPLATE = """ +FMHA_FORWARD_INSTANCE_TEMPLATE_INC = """ #include #include \"ck_tiled_fmha_{mode}_forward.h\" +""" -template void run_{mode}_forward_causalmask_bias_dropout_dispatch< +FMHA_FORWARD_INSTANCE_TEMPLATE = """ +{extern}template void run_{mode}_forward_causalmask_bias_dropout_dispatch< {dtype}, {has_causalmask}, {has_bias}, @@ -53,11 +59,13 @@ "{has_or_no_bias_str}_{has_or_no_dropout_str}_{max_k_str}.cpp" ) -FMHA_BACKWARD_INSTANCE_TEMPLATE = """ +FMHA_BACKWARD_INSTANCE_TEMPLATE_INC = """ #include #include \"ck_tiled_fmha_{mode}_backward.h\" +""" -template void run_{mode}_backward_causalmask_bias_dropout_dispatch< +FMHA_BACKWARD_INSTANCE_TEMPLATE = """ +{extern}template void run_{mode}_backward_causalmask_bias_dropout_dispatch< {dtype}, {has_causalmask}, {has_bias}, @@ -71,6 +79,8 @@ "{has_or_no_bias_str}_{has_or_no_biasgrad_str}_{has_or_no_dropout_str}_{max_k_str}.cpp" ) +FMHA_INSTANCE_REF_FNAME = "fmha_{mode}_{function}_{dtype}_instances_ref.h" + BOOL_MAP = {True: "true", False: "false"} BOOL_MAP_CAUSALMASK = { @@ -116,13 +126,13 @@ } -def create_infer_instances(instance_dir: Path) -> None: +def create_infer_instances(instance_dir: Path, headdims: List) -> None: for mode in ["batched", "grouped"]: for dtype in ["fp16", "bf16"]: for has_causalmask in [True, False]: for has_bias in [True, False]: for has_dropout in [True, False]: - for max_k in [32, 64, 128, 256]: + for max_k in headdims: fname = FMHA_INFER_INSTANCE_FNAME.format( mode=mode, dtype_str=dtype, @@ -133,9 +143,15 @@ def create_infer_instances(instance_dir: Path) -> None: has_or_no_dropout_str=BOOL_MAP_DROPOUT[has_dropout], max_k_str=INT_MAP_MAX_K[max_k], ) + infer_instance_inc = ( + FMHA_INFER_INSTANCE_TEMPLATE_INC.format( + mode=mode, + dtype_file=TYPE_FNAME_MAP[dtype], + ) + ) infer_instance = FMHA_INFER_INSTANCE_TEMPLATE.format( + extern="", mode=mode, - dtype_file=TYPE_FNAME_MAP[dtype], dtype=TYPE_CTYPE_MAP[dtype], has_causalmask=BOOL_MAP[has_causalmask], has_bias=BOOL_MAP[has_bias], @@ -144,17 +160,52 @@ def create_infer_instances(instance_dir: Path) -> None: cap_mode=MODE_NAME_MAP[mode], ) (instance_dir / fname).write_text( - FMHA_INSTANCE_HEADER + infer_instance + FMHA_COPYRIGHT_HEADER + + infer_instance_inc + + infer_instance ) -def create_forward_instances(instance_dir: Path) -> None: +def create_infer_instances_ref(instance_dir: Path, headdims: List) -> None: + for mode in ["batched", "grouped"]: + for dtype in ["fp16", "bf16"]: + ref_fname = FMHA_INSTANCE_REF_FNAME.format( + mode=mode, + function="infer", + dtype=dtype, + ) + ref_fname_path = instance_dir / ref_fname + infer_instance_inc = FMHA_INFER_INSTANCE_TEMPLATE_INC.format( + mode=mode, + dtype_file=TYPE_FNAME_MAP[dtype], + ) + with open(ref_fname_path, "a") as file: + file.write(FMHA_COPYRIGHT_HEADER) + file.write(infer_instance_inc) + for max_k in headdims: + for has_bias in [True, False]: + for has_dropout in [True, False]: + for has_causalmask in [True, False]: + infer_instance = FMHA_INFER_INSTANCE_TEMPLATE.format( + extern="extern ", + mode=mode, + 
dtype=TYPE_CTYPE_MAP[dtype], + has_causalmask=BOOL_MAP[has_causalmask], + has_bias=BOOL_MAP[has_bias], + has_dropout=BOOL_MAP[has_dropout], + max_k=max_k, + cap_mode=MODE_NAME_MAP[mode], + ) + file.write(infer_instance) + + +def create_forward_instances(instance_dir: Path, headdims: List) -> None: for mode in ["batched", "grouped"]: for dtype in ["fp16", "bf16"]: for has_causalmask in [True, False]: for has_bias in [True, False]: for has_dropout in [True, False]: - for max_k in [32, 64, 128, 256]: + for max_k in headdims: fname = FMHA_FORWARD_INSTANCE_FNAME.format( mode=mode, dtype_str=dtype, @@ -165,9 +216,15 @@ def create_forward_instances(instance_dir: Path) -> None: has_or_no_dropout_str=BOOL_MAP_DROPOUT[has_dropout], max_k_str=INT_MAP_MAX_K[max_k], ) - infer_instance = FMHA_FORWARD_INSTANCE_TEMPLATE.format( + forward_instance_inc = ( + FMHA_FORWARD_INSTANCE_TEMPLATE_INC.format( + mode=mode, + dtype_file=TYPE_FNAME_MAP[dtype], + ) + ) + forward_instance = FMHA_FORWARD_INSTANCE_TEMPLATE.format( + extern="", mode=mode, - dtype_file=TYPE_FNAME_MAP[dtype], dtype=TYPE_CTYPE_MAP[dtype], has_causalmask=BOOL_MAP[has_causalmask], has_bias=BOOL_MAP[has_bias], @@ -176,11 +233,48 @@ def create_forward_instances(instance_dir: Path) -> None: cap_mode=MODE_NAME_MAP[mode], ) (instance_dir / fname).write_text( - FMHA_INSTANCE_HEADER + infer_instance + FMHA_COPYRIGHT_HEADER + + forward_instance_inc + + forward_instance ) -def create_backward_instances(instance_dir: Path) -> None: +def create_forward_instances_ref(instance_dir: Path, headdims: List) -> None: + for mode in ["batched", "grouped"]: + for dtype in ["fp16", "bf16"]: + ref_fname = FMHA_INSTANCE_REF_FNAME.format( + mode=mode, + function="forward", + dtype=dtype, + ) + ref_fname_path = instance_dir / ref_fname + forward_instance_inc = FMHA_FORWARD_INSTANCE_TEMPLATE_INC.format( + mode=mode, + dtype_file=TYPE_FNAME_MAP[dtype], + ) + with open(ref_fname_path, "a") as file: + file.write(FMHA_COPYRIGHT_HEADER) + file.write(forward_instance_inc) + for max_k in headdims: + for has_bias in [True, False]: + for has_dropout in [True, False]: + for has_causalmask in [True, False]: + forward_instance = ( + FMHA_FORWARD_INSTANCE_TEMPLATE.format( + extern="extern ", + mode=mode, + dtype=TYPE_CTYPE_MAP[dtype], + has_causalmask=BOOL_MAP[has_causalmask], + has_bias=BOOL_MAP[has_bias], + has_dropout=BOOL_MAP[has_dropout], + max_k=max_k, + cap_mode=MODE_NAME_MAP[mode], + ) + ) + file.write(forward_instance) + + +def create_backward_instances(instance_dir: Path, headdims: List) -> None: for mode in ["batched", "grouped"]: for dtype in ["fp16", "bf16"]: for has_causalmask in [True, False]: @@ -190,7 +284,7 @@ def create_backward_instances(instance_dir: Path) -> None: [False, False], ]: for has_dropout in [True, False]: - for max_k in [32, 64, 128]: + for max_k in headdims: fname = FMHA_BACKWARD_INSTANCE_FNAME.format( mode=mode, dtype_str=dtype, @@ -202,9 +296,15 @@ def create_backward_instances(instance_dir: Path) -> None: has_or_no_dropout_str=BOOL_MAP_DROPOUT[has_dropout], max_k_str=INT_MAP_MAX_K[max_k], ) - infer_instance = FMHA_BACKWARD_INSTANCE_TEMPLATE.format( + backward_instance_inc = ( + FMHA_BACKWARD_INSTANCE_TEMPLATE_INC.format( + mode=mode, + dtype_file=TYPE_FNAME_MAP[dtype], + ) + ) + backward_instance = FMHA_BACKWARD_INSTANCE_TEMPLATE.format( + extern="", mode=mode, - dtype_file=TYPE_FNAME_MAP[dtype], dtype=TYPE_CTYPE_MAP[dtype], has_causalmask=BOOL_MAP[has_causalmask], has_bias=BOOL_MAP[has_bias], @@ -214,14 +314,77 @@ def 
create_backward_instances(instance_dir: Path) -> None: cap_mode=MODE_NAME_MAP[mode], ) (instance_dir / fname).write_text( - FMHA_INSTANCE_HEADER + infer_instance + FMHA_COPYRIGHT_HEADER + + backward_instance_inc + + backward_instance ) +def create_backward_instances_ref(instance_dir: Path, headdims: List) -> None: + for mode in ["batched", "grouped"]: + for dtype in ["fp16", "bf16"]: + ref_fname = FMHA_INSTANCE_REF_FNAME.format( + mode=mode, + function="backward", + dtype=dtype, + ) + ref_fname_path = instance_dir / ref_fname + backward_instance_inc = FMHA_BACKWARD_INSTANCE_TEMPLATE_INC.format( + mode=mode, + dtype_file=TYPE_FNAME_MAP[dtype], + ) + with open(ref_fname_path, "a") as file: + file.write(FMHA_COPYRIGHT_HEADER) + file.write(backward_instance_inc) + for max_k in headdims: + for has_bias, has_bias_grad in [ + [True, False], + [True, True], + [False, False], + ]: + for has_dropout in [True, False]: + for has_causalmask in [True, False]: + backward_instance = ( + FMHA_BACKWARD_INSTANCE_TEMPLATE.format( + extern="extern ", + mode=mode, + dtype=TYPE_CTYPE_MAP[dtype], + has_causalmask=BOOL_MAP[has_causalmask], + has_bias=BOOL_MAP[has_bias], + has_bias_grad=BOOL_MAP[has_bias_grad], + has_dropout=BOOL_MAP[has_dropout], + max_k=max_k, + cap_mode=MODE_NAME_MAP[mode], + ) + ) + file.write(backward_instance) + + if __name__ == "__main__": + disable_hd256 = False + + for arg in sys.argv: + if arg == "--ignore-hd256": + disable_hd256 = True + + if disable_hd256: + headdims = [32, 64, 128] + else: + headdims = [32, 64, 128, 256] + this_dir = os.path.dirname(__file__) output_dir = Path(this_dir) / "instances" output_dir.mkdir(parents=True, exist_ok=True) - create_infer_instances(output_dir) - create_forward_instances(output_dir) - create_backward_instances(output_dir) + + # remove existing files in the directory + files = os.listdir(output_dir) + for ff in files: + file_path = os.path.join(output_dir, ff) + os.remove(file_path) + + create_infer_instances(output_dir, headdims) + create_infer_instances_ref(output_dir, headdims) + create_forward_instances(output_dir, headdims) + create_forward_instances_ref(output_dir, headdims) + create_backward_instances(output_dir, headdims) + create_backward_instances_ref(output_dir, headdims) diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp index 97f209cb6..b129b0719 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp new file mode 100644 index 000000000..58aaac801 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + true, + 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp index 5c0e89e21..73360d7dc 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp index 5e3392493..7f99b4819 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp index ae9158e21..b831c919d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
* * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp new file mode 100644 index 000000000..1829f50f2 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + false, + 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp index dfc929276..74501e007 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp index a915f8aa5..62a1c9d0b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp index 7e17c9298..b5b258196 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. 
All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp new file mode 100644 index 000000000..070e8b2c0 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + true, + 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp index 8d980af34..504c22609 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp index be31aa59b..573d9bf4b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp index 7ea9cb0a9..67bf8995c 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp new file mode 100644 index 000000000..4bc3b5a83 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + false, + 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp index a2a9dd4d6..331b79140 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp index 594a62ff5..1c3a956d4 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
* * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp index 0307f9ab2..0d902e120 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp new file mode 100644 index 000000000..13dfd5a09 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + true, + 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp index 5a7cd479a..e6b8fd85f 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp index e1280f6d2..4c2c0672e 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. 
All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp index 04a107af4..68bac14f2 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp new file mode 100644 index 000000000..2a72588f1 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + false, + 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp index 0a41a2f27..ea7baeea2 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
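Each of the instance files touched or added above holds exactly one explicit instantiation, and its filename encodes the template arguments positionally: <dtype, has_causalmask, has_bias, has_bias_grad, has_dropout, max_k>. As a reading aid, here is a small hypothetical decoder (not part of this change) that maps a generated filename back to its instantiation:

```python
# Sketch: recover the template arguments encoded in an instance filename.
import re

PATTERN = re.compile(
    r"fmha_(?P<mode>batched|grouped)_backward_(?P<dtype>fp16|bf16)"
    r"_(?P<causalmask>has|no)_causalmask"
    r"_(?P<bias>has|no)_bias"
    r"_(?P<biasgrad>has|no)_biasgrad"
    r"_(?P<dropout>has|no)_dropout"
    r"_maxk_(?P<maxk>\d+)\.cpp"
)

def decode(fname: str) -> str:
    m = PATTERN.fullmatch(fname)
    assert m is not None, f"not a generated backward instance: {fname}"
    # Flag order mirrors the template parameter order in the .cpp bodies.
    flags = ", ".join(
        str(m[g] == "has").lower() for g in ("causalmask", "bias", "biasgrad", "dropout")
    )
    return (
        f"run_{m['mode']}_backward_causalmask_bias_dropout_dispatch"
        f"<ck_tile::{m['dtype']}_t, {flags}, {m['maxk']}>"
    )

print(decode(
    "fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp"
))
```

Running this prints run_batched_backward_causalmask_bias_dropout_dispatch<ck_tile::bf16_t, true, true, false, true, 256>, which matches the body of the corresponding new maxk_256 file above.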
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp index 49d6b9641..202882678 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_instances_ref.h b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_instances_ref.h new file mode 100644 index 000000000..06f82124a --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_instances_ref.h @@ -0,0 +1,396 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + true, + 32>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + true, + 32>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + false, + 32>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + false, + 32>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + true, + 32>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + true, + 32>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + false, + 32>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + false, + 32>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + true, + 32>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + true, + 32>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + 
true, + false, + false, + false, + 32>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + false, + 32>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + true, + 64>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + true, + 64>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + false, + 64>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + false, + 64>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + true, + 64>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + true, + 64>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + false, + 64>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + false, + 64>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + true, + 64>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + true, + 64>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + false, + 64>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + false, + 64>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + true, + 128>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + true, + 128>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + false, + 128>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + false, + 128>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + true, + 128>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void 
run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + true, + 128>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + false, + 128>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + false, + 128>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + true, + 128>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + true, + 128>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + false, + 128>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + false, + 128>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + true, + 256>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + true, + 256>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + false, + 256>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + false, + 256>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + true, + 256>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + true, + 256>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + false, + 256>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + false, + 256>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + true, + 256>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + true, + 256>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + false, + 256>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + false, + 
256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp index f5ce7c5bb..8689b5389 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp new file mode 100644 index 000000000..fd52bcc4d --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + true, + 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp index 41ff265c7..2a5977be3 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp index f6b776650..490659b74 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. 
All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp index 7f4013aaf..f4f3ac89c 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp new file mode 100644 index 000000000..4067c8e5a --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + false, + 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp index 5241a1b1f..c3dd3d5fe 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
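The new fmha_batched_backward_bf16_instances_ref.h above is the declaration-side twin of the .cpp instances: the generator renders the same template string twice, once with extern="extern " (collected into the ref header) and once without (one definition per .cpp). The extern template declarations stop every including translation unit from instantiating the kernels implicitly, so each combination is compiled exactly once. A quick sanity check on the count, assuming the four head dims (DECL below is a trimmed stand-in for the real template string):

```python
from itertools import product

HEADDIMS = [32, 64, 128, 256]
BIAS_COMBOS = [(True, False), (True, True), (False, False)]  # (has_bias, has_bias_grad)

# 4 head dims x 3 bias combos x 2 dropout x 2 causal-mask = 48 instances,
# matching the 48 extern declarations in the bf16 ref header above.
print(len(list(product(HEADDIMS, BIAS_COMBOS, (True, False), (True, False)))))  # 48

# Trimmed stand-in for the real template string; only the extern prefix differs.
DECL = (
    "{extern}template void run_batched_backward_causalmask_bias_dropout_dispatch<\n"
    "    ...>(BatchedBackwardParams& param, hipStream_t stream);"
)
print(DECL.format(extern="extern "))  # declaration: appended to the _ref.h header
print(DECL.format(extern=""))         # definition: one per generated .cpp
```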
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp index f5ee944eb..d8fd52d7a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp index 8ab3f930c..f9e140aae 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp new file mode 100644 index 000000000..71b1586ac --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + true, + 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp index c757b7d35..5688539e8 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
* * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp index 4b3d9f256..a820ad76c 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp index 03455ee6e..fbd6b8b48 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp new file mode 100644 index 000000000..b64b16b8d --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + false, + 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp index 48a501539..db6ee679c 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp index d73c780a6..e79dd63df 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp index c0636a905..35a968405 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp new file mode 100644 index 000000000..14d935611 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + true, + 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp index 3da3474df..783c741b6 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp index 6ed11608d..7ddd65d11 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp index 3cca920f5..69e698344 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp new file mode 100644 index 000000000..5fa39c880 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + false, + 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp index 6383d494e..fed439c70 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp index 585dc69f3..6a955e982 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp index 6ca73178d..b4df2bf40 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp new file mode 100644 index 000000000..545a77955 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + true, + 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp index 95218766e..1da7bae3a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp index bf092ff96..4c3cf7ff6 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp index 394bbbe28..1cbafbf70 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp new file mode 100644 index 000000000..f1e9009d1 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + false, + 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp index ea3884557..951196506 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp index 4596bfd7f..75fef6ab4 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp index e1d72bc58..836e9428e 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
* * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp new file mode 100644 index 000000000..cf89aa7bd --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + true, + 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp index 96f62e9ac..bbc4eea82 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp index dd72c62f2..2d804bd5d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp index a0d7a83d9..3b85cea79 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. 
All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp new file mode 100644 index 000000000..f261d64ba --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + false, + 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp index e2d01f97e..635f9f1a2 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp index d5378b3f3..919a01fb9 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
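For context, a caller selects one of these pre-built instantiations by mapping runtime parameters onto the template arguments. The sketch below is illustrative only, not the project's actual dispatcher: the helper name is made up, and it assumes BatchedBackwardParams carries the head dimension as param.K. It shows a single flag combination (causal mask and bias, no bias grad, no dropout) being routed to the smallest compiled head-dimension bucket that fits:

    #include <stdexcept>
    #include "ck_tiled_fmha_batched_backward.h"

    // Hypothetical dispatcher for one flag combination; the real code covers
    // every combination instantiated in this diff.
    void dispatch_fp16_backward(BatchedBackwardParams& param, hipStream_t stream) {
      if (param.K <= 32)
        run_batched_backward_causalmask_bias_dropout_dispatch<
            ck_tile::fp16_t, true, true, false, false, 32>(param, stream);
      else if (param.K <= 64)
        run_batched_backward_causalmask_bias_dropout_dispatch<
            ck_tile::fp16_t, true, true, false, false, 64>(param, stream);
      else if (param.K <= 128)
        run_batched_backward_causalmask_bias_dropout_dispatch<
            ck_tile::fp16_t, true, true, false, false, 128>(param, stream);
      else if (param.K <= 256)
        run_batched_backward_causalmask_bias_dropout_dispatch<
            ck_tile::fp16_t, true, true, false, false, 256>(param, stream);
      else
        throw std::runtime_error("no instance compiled for head dim > 256");
    }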
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp index 02c8c9bc5..bdf72b91a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp new file mode 100644 index 000000000..2588185d9 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + true, + 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp index 8057c759e..087b8e1c8 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp index af6091b25..d01cb1e37 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
* * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp index 3fc748ff2..99a2823b4 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp new file mode 100644 index 000000000..acceefffb --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + false, + 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp index b9b6aacfe..ac3a2a5fd 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp index 8b667d2f7..5a281913f 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_instances_ref.h b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_instances_ref.h new file mode 100644 index 000000000..d47f8cc1e --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_instances_ref.h @@ -0,0 +1,396 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + true, + 32>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + true, + 32>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + false, + 32>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + false, + 32>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + true, + 32>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + true, + 32>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + false, + 32>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + false, + 32>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + true, + 32>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + true, + 32>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + false, + 32>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + false, + 32>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + true, + 64>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + true, + 64>(BatchedBackwardParams& param, hipStream_t stream); + +extern 
template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + false, + 64>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + false, + 64>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + true, + 64>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + true, + 64>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + false, + 64>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + false, + 64>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + true, + 64>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + true, + 64>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + false, + 64>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + false, + 64>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + true, + 128>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + true, + 128>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + false, + 128>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + false, + 128>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + true, + 128>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + true, + 128>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + false, + 128>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + false, + 128>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + true, + 
128>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + true, + 128>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + false, + 128>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + false, + 128>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + true, + 256>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + true, + 256>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + false, + 256>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + false, + 256>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + true, + 256>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + true, + 256>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + false, + 256>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + false, + 256>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + true, + 256>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + true, + 256>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + false, + 256>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + false, + 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp index df1e6c3c0..68ffee4bf 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro 
Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp new file mode 100644 index 000000000..4d84693d6 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + true, + 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp index f415d9464..8b498600a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp index ff8d33f21..7ddd6efd8 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
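The fmha_batched_backward_fp16_instances_ref.h header added above consists entirely of extern template declarations, 48 in total (2 causal-mask settings x 3 valid bias/bias-grad combinations x 2 dropout settings x 4 head-dim buckets). An extern template declaration tells every translation unit that includes the header not to instantiate the template itself; the matching explicit instantiation definition lives in exactly one generated .cpp file, so each kernel is compiled exactly once. A generic illustration of the pattern, with made-up names:

    // widget.h: the declaration suppresses implicit instantiation in includers.
    template <int N> struct Widget { int size() const { return N; } };
    extern template struct Widget<32>;

    // widget_32.cpp: the one translation unit that actually instantiates it.
    template struct Widget<32>;

Splitting the several hundred instances into one translation unit each keeps per-file compile time and memory bounded and lets the build parallelize across files.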
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp index 41da7ab90..d1bdf1fa5 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp new file mode 100644 index 000000000..b8c8eb5b3 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + false, + 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp index 340fb65ee..60553e405 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp index be7f2144d..dafd1d5d2 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
* * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp index 0932fbb12..dd6ef7d00 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp new file mode 100644 index 000000000..daee39215 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + true, + 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp index eaafd9949..dc1971262 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp index 02cf83aba..e9c8d75e3 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. 
All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp index 51bd8bedb..bc25646dc 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp new file mode 100644 index 000000000..a324ea3d1 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + false, + 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp index 7f999c203..8ffe3a4c3 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
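One constraint is visible in the file set itself: every no_bias backward instance is also no_biasgrad (the bias-grad boolean is false whenever the bias boolean is), since a bias gradient can only be requested when a bias input exists. A hypothetical compile-time guard, not present in the codebase, expressing the same rule:

    // Made-up wrapper: rejects kHasBiasGrad without kHasBias at compile time.
    template <bool kHasCausalMask, bool kHasBias, bool kHasBiasGrad, bool kHasDropout>
    void checked_backward_dispatch(BatchedBackwardParams& param, hipStream_t stream) {
      static_assert(kHasBias || !kHasBiasGrad,
                    "a bias gradient requires a bias input");
      // ... forward to the matching instantiation ...
    }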
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp index 3ad410861..0d3ab043e 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp index 90572aabf..64c0c14fb 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp new file mode 100644 index 000000000..2d0e3efaa --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + true, + 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp index 9c0000820..003201abf 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
* * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp index 13902640d..a6570b6bf 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp index 82849155e..a23a7087d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp new file mode 100644 index 000000000..274405d53 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + false, + 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp index 81636cea6..46a8e8a4d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp index 97775f0e2..5bdd29dbd 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp index 5a639ee11..189677f41 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp index 29cf57025..39881bd0d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp index c60d415d4..a24b8868a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
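From here the diff moves on to the batched forward bf16 instances. The generated file names encode the template arguments directly, one has_/no_ field per boolean plus the maxk bucket. A made-up helper, not the project's actual generator, that rebuilds a forward instance file name from its arguments (backward names additionally carry a has_/no_biasgrad field):

    #include <string>

    // Illustrative utility: reconstructs the forward-file naming convention
    // seen in this diff from the template arguments.
    std::string forward_instance_file(const std::string& dtype,  // "fp16"/"bf16"
                                      bool causal, bool bias, bool dropout,
                                      int maxk) {
      auto flag = [](bool on, const char* name) {
        return std::string(on ? "has_" : "no_") + name + "_";
      };
      return "fmha_batched_forward_" + dtype + "_" + flag(causal, "causalmask") +
             flag(bias, "bias") + flag(dropout, "dropout") + "maxk_" +
             std::to_string(maxk) + ".cpp";
    }

For example, forward_instance_file("bf16", true, false, true, 256) yields fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp, one of the files touched below.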
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp index f6291e2db..849a6633b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp index caec04c71..c49a96edb 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp index ae29f02a3..f362ff83b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp index 71eda93e9..62205efbd 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp index aa31f0f84..c485fdfcd 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp index 551c4eb67..68345b50d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp index 1d6e78baf..4e3144c61 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp index 278f6d358..1654eb535 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp index 18e12c0a4..fef0b43b9 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp index d393e26c3..87d8256c2 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp index e5e99ede0..521469e26 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp index 672b58be1..d2eeed020 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
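The forward reference header added next mirrors the backward one but with one fewer template flag: the forward pass computes no gradients, so there is no bias-grad position, giving 32 declarations (2 x 2 x 2 flag settings x 4 head-dim buckets). The assumed shape of the declaration, with names chosen for readability and positions matching the instantiations that follow:

    // Assumed names; positions match the forward instances in this diff:
    template <
        typename ScalarType,    // ck_tile::bf16_t throughout this header
        bool kHasCausalMask,
        bool kHasBias,
        bool kHasDropout,       // no bias-grad flag in the forward path
        ck_tile::index_t MaxK>  // type assumed, as in the backward sketch
    void run_batched_forward_causalmask_bias_dropout_dispatch(
        BatchedForwardParams& param,
        hipStream_t stream);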
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp
index ed42d7c0b..77e509f0c 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp
@@ -1,6 +1,6 @@
 /*
-  Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+  Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_instances_ref.h b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_instances_ref.h
new file mode 100644
index 000000000..8fab725be
--- /dev/null
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_instances_ref.h
@@ -0,0 +1,236 @@
+
+/*
+  Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ *
+ * The file is automatically generated, don't modify!
+ */
+
+#include
+#include "ck_tiled_fmha_batched_forward.h"
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    true,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    true,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    false,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    false,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    true,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    true,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    false,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    false,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    true,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    true,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    false,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    false,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    true,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    true,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    false,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    false,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    true,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    true,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    false,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    false,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    true,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    true,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    false,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    false,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    true,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    true,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    false,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    false,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    true,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    true,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    false,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    false,
+    256>(BatchedForwardParams& param, hipStream_t stream);
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp
index 7e71f6b27..b0898e658 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp
@@ -1,6 +1,6 @@
 /*
-  Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+  Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp
index 5f0af8c18..aee8358c1 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp
@@ -1,6 +1,6 @@
 /*
-  Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+  Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp
index 3aac80d51..b949c5557 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp
@@ -1,6 +1,6 @@
 /*
-  Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+  Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp
index 8018e467f..3e28448d4 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp
@@ -1,6 +1,6 @@
 /*
-  Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+  Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp index 0266d3a36..eae1bef14 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp index d327faf63..3fea67a9d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp index af2c6e8de..e9e1d8c03 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp index 722dc77bb..0b5b5e9ac 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp index 9ab840b67..20e880ae3 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp index 6b6c4b6a1..2d9e145b8 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp index afd3bcfc3..12c05851b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp index a349964c0..296c93e84 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp index 03eb236cc..ffcd7f0d8 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp index 19dc010e4..a0fbb353f 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp index 14272770f..729e834bf 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp index bf7aefc53..b2ee36ac2 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp index 6e2e94259..e9c50c43e 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp index e08bb00a1..98ad34421 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp index 96de7b864..df8cb489a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp index f82f2b471..9ff6b6346 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp index 60eda29ce..8e5fc2b22 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp index 9cb7c591b..8489a8255 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp index effc47a63..0ab15f431 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp index 477ec5f36..89b57dc00 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp index b75a4f46f..286ce1f10 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp index 322d9c2e2..0a32ecd5e 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp index 77fb6a604..5caa44509 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp index 57214e6f3..7b45b7050 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp index 3b4f1be34..ea683ccd0 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp index afc858efb..c17397faf 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp index bdf207633..6483bd6da 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp index ea656db19..607227078 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_instances_ref.h b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_instances_ref.h
new file mode 100644
index 000000000..d69766972
--- /dev/null
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_instances_ref.h
@@ -0,0 +1,236 @@
+
+/*
+  Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ *
+ * The file is automatically generated, don't modify!
+ */
+
+#include
+#include "ck_tiled_fmha_batched_forward.h"
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    true,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    true,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    false,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    false,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    true,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    true,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    false,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    false,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    true,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    true,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    false,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    false,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    true,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    true,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    false,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    false,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    true,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    true,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    false,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    false,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    true,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    true,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    false,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    false,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    true,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    true,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    false,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    false,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    true,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    true,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    false,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    false,
+    256>(BatchedForwardParams& param, hipStream_t stream);
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp
index 5d65d7ae7..1af052fb6 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp
@@ -1,6 +1,6 @@
 /*
-  Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+  Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp
index 709138805..5616cdc52 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp
@@ -1,6 +1,6 @@
 /*
-  Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+  Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp
index c50e52c86..8b10f1192 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp
@@ -1,6 +1,6 @@
 /*
-  Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+  Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp
index 1808842fc..988a2fe2b 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp
@@ -1,6 +1,6 @@
 /*
-  Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+  Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp
index 367c420a4..9b5b928f7 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp
@@ -1,6 +1,6 @@
 /*
-  Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+  Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp index 8f213bfef..1b36a0d25 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp index fd5da6b77..785ecd397 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp index 70e0723bb..82199beb7 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp index 4f8e39ac1..e18cda6c9 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp index 3d3be36e9..ed23610a9 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp index 21aae8f7c..2e512e089 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp index 514a01a39..cfd204f04 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp index c67d1c653..f161893bd 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp index 810036325..c37fb70c9 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp index 7dda46c89..f05aca856 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp index 2392b9498..cd0f3d4ff 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp index 74743b024..ad22843e3 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp index 20290bab8..a457b90f3 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp index ab3225bd4..51d21df17 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp index 310442726..0c2a21bf6 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp index af36d315e..4e33efc72 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp index b25e1be08..f3eb7b0ec 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp index 5e660a8ea..d8db2ebe2 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp index 39153d92f..72e7fb412 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp index bf3c3f21a..0b4ed8294 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp index e9c1c0551..2e752c941 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp index e35a1e7a5..68366ee2f 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp index 577972843..9d0c50e13 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp index bb48b49d2..8129cbf85 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp index d13429529..3d6e897a4 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp index 5d44df43a..c264d95ad 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp index aadd0fcca..fb8e9fb0a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_instances_ref.h b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_instances_ref.h new file mode 100644 index 000000000..003d76894 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_instances_ref.h @@ -0,0 +1,236 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */
+
+#include <ck_tile/core.hpp>
+#include "ck_tiled_fmha_batched_infer.h"
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    true,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    true,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    false,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    false,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    true,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    true,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    false,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    false,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    true,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    true,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    false,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    false,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    true,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    true,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    false,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    false,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    true,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    true,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    false,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    false,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    true,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    true,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    false,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    false,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    true,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    true,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    false,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    false,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    true,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    true,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    false,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    false,
+    256>(BatchedForwardParams& param, hipStream_t stream);
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp
index 034275f69..db28d72f4 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp
@@ -1,6 +1,6 @@
 
 /*
-  Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+  Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp index c922b00c0..228bb5397 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp index 8edd6fed5..d0152e160 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp index e2d8ba101..8cb88dd94 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp index 9e9adf31d..25c006c09 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp index 306829eaf..77ab1fc3e 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp index 8bfc62104..15311470c 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp index fe81acab4..4c98864b2 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp index bcf5b783f..d20c61ee1 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp index ba5a41450..0410708e1 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp index 9cac1c3af..d837f7b54 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp index e31ed4362..7462600fb 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp index 9f52f52be..65d1fd39a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp index 9ba93c82c..c0ea4369a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp index fec45193d..b46f0c0c8 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp index 571f8ad48..8051de4d9 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp index 76447cfef..c1ee8c769 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp index 94e2e0dfc..46a38e82d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp index 432d955b7..6040d41cd 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp index 173d18aaf..db5d5d577 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp index 7661a50d3..ccc0a0254 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp index b3e43957f..d81ff0d38 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp index f54aa9ef4..48b74b2bc 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp index 17f4018c3..fda07f6cd 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp index d5ea02d7c..43069dd54 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp index 2e4a6769e..bf8afd424 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp index 6caae1a75..351f5ea1d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp index c01f1105b..d06dc1f10 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp index 4e146ec41..df91366da 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp index e5bc54c2c..4c292918b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp index ac3f5d082..9dc31e3ea 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp index 3f39b0323..2bbd4f3dd 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_instances_ref.h b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_instances_ref.h new file mode 100644 index 000000000..266b3643e --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_instances_ref.h @@ -0,0 +1,236 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */
+
+#include <ck_tile/core.hpp>
+#include "ck_tiled_fmha_batched_infer.h"
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    true,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    true,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    false,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    false,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    true,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    true,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    false,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    false,
+    32>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    true,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    true,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    false,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    false,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    true,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    true,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    false,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    false,
+    64>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    true,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    true,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    false,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    false,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    true,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    true,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    false,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    false,
+    128>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    true,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    true,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    false,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    false,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    true,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    true,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    false,
+    256>(BatchedForwardParams& param, hipStream_t stream);
+
+extern template void run_batched_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    false,
+    256>(BatchedForwardParams& param, hipStream_t stream);
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp
index 7440bc503..37f18fd7d 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp
@@ -1,6 +1,6 @@
 
 /*
-  Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+  Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp index efaf98472..dd5ec2118 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp index 0820075e5..3afe1c2f8 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp index 89dace195..e9ddc972d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp index 95f57c099..609b4981c 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp index c8ac55329..5fca4f4ee 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp index 10a261f3d..fe3a2e2bc 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp index 721145717..d077701b9 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp index be3100082..501a83e9a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp index 7c70e53b9..d0b619f60 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp index 75f733259..af0bc1c85 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp index 50507e69c..578454c52 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp index 931040548..d20d225cd 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp index a1a08d4d5..ce76fd765 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp index 200706066..ca44ac6b0 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp index 9db040363..5d7589a16 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp index 72fec2837..c22b793d3 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp new file mode 100644 index 000000000..f4b7a307a --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + true, + 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp index 5b3551d3b..c5b1454c5 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp index c9ca1a559..c8c71960d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp index 09daabcfa..de55b8e88 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
* * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp new file mode 100644 index 000000000..577c43def --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + false, + 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp index 0bc605677..9ffa70e78 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp index 489610171..71ac1de6f 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp index 3e9ba0cba..f2baaf01d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. 
All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp new file mode 100644 index 000000000..18d194062 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + true, + 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp index 3e13c1b17..8e87f044d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp index b5023fdc8..dbe7c0560 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
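
Each generated instance file's name encodes, token for token, the template arguments of the single explicit instantiation it contains. A minimal annotated sketch of that mapping, using the grouped-backward bf16 maxk-256 instance added above; the per-parameter labels in the comments are inferred from the file-name tokens, not identifiers taken from the dispatch header:

template void run_grouped_backward_causalmask_bias_dropout_dispatch<
    ck_tile::bf16_t, // scalar type        <- "bf16" (fp16 variants exist too)
    true,            // causal mask        <- "has_causalmask"
    true,            // attention bias     <- "has_bias"
    true,            // bias gradient      <- "has_biasgrad"
    true,            // dropout            <- "has_dropout"
    256>             // max head dimension <- "maxk_256"
    (GroupedBackwardParams& param, hipStream_t stream);
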
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp index 7c3a7a165..7a293a973 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp new file mode 100644 index 000000000..dc5f5c749 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + false, + 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp index 73cd48382..8b878747f 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp index f9163241f..1871a6cbe 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
* * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp index 55fa67c3d..295e3f403 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp new file mode 100644 index 000000000..e23b3c60b --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + true, + 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp index 3549f1148..08af2d667 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp index e8735e590..4d2d7e78d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. 
All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp index 43586d91c..43fc95070 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp new file mode 100644 index 000000000..b85fa82e9 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + false, + 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp index 6e6e44a15..86d8d4776 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
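
Note the build-level pattern here: every new .cpp holds exactly one explicit instantiation definition, so each flag/head-dim combination compiles in its own small translation unit and the family as a whole can be built in parallel. A generic sketch of the C++ mechanism these files rely on, with placeholder names rather than the xformers ones:

// dispatch.hpp: the template definition lives in a header.
template <typename T, bool kFlag, int kMaxK>
void run_dispatch(T* data, int n) {
  // ... heavy kernel-selection and launch code ...
}

// instance_fp32_true_128.cpp: includes dispatch.hpp, then provides one
// explicit instantiation *definition*; object code for exactly this
// combination is emitted here and only here.
template void run_dispatch<float, true, 128>(float* data, int n);
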
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp index 16c69fc8f..e8e862d54 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_instances_ref.h b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_instances_ref.h new file mode 100644 index 000000000..870b4dda9 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_instances_ref.h @@ -0,0 +1,396 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + true, + 32>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + true, + 32>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + false, + 32>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + false, + 32>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + true, + 32>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + true, + 32>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + false, + 32>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + false, + 32>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + true, + 32>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + true, + 32>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + 
true, + false, + false, + false, + 32>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + false, + 32>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + true, + 64>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + true, + 64>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + false, + 64>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + false, + 64>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + true, + 64>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + true, + 64>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + false, + 64>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + false, + 64>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + true, + 64>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + true, + 64>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + false, + 64>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + false, + 64>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + true, + 128>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + true, + 128>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + false, + 128>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + false, + 128>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + true, + 128>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void 
run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + true, + 128>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + false, + 128>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + false, + 128>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + true, + 128>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + true, + 128>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + false, + 128>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + false, + 128>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + true, + 256>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + true, + 256>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + false, + 256>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + false, + 256>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + true, + 256>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + true, + 256>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + false, + 256>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + false, + 256>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + true, + 256>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + true, + 256>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + false, + 256>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + false, + 
256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp index c590ef5a4..76a4e7dcb 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp new file mode 100644 index 000000000..a4b3c633d --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + true, + 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp index 6e283c09f..1ba22ae61 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp index 6d3aebee2..07813b2c5 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. 
All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp index 62da5b2b3..42818cfa9 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp new file mode 100644 index 000000000..07b019af4 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + false, + 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp index 28184d919..485b64775 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
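
The new fmha_grouped_backward_bf16_instances_ref.h above is the declaration-side counterpart: for each head-dimension bucket (32, 64, 128, and now 256) it lists extern template declarations so other translation units never instantiate these combinations implicitly. Per bucket there are 12 declarations rather than 16, i.e. 2 (causal) x 2 (dropout) x 3 valid (bias, biasgrad) pairs: bias-gradient instances are only generated when a bias exists, so the no_bias/has_biasgrad combinations are skipped. An illustrative guard for that invariant (not code from the xformers sources):

// A bias gradient presupposes a bias input; instance generation
// effectively enforces has_bias || !has_biasgrad.
template <bool kHasBias, bool kHasBiasGrad>
constexpr bool kValidBiasCombo = kHasBias || !kHasBiasGrad;

static_assert(kValidBiasCombo<true, true>);   // has_bias + has_biasgrad: generated
static_assert(kValidBiasCombo<false, false>); // no_bias + no_biasgrad: generated
static_assert(!kValidBiasCombo<false, true>); // no_bias + has_biasgrad: never emitted
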
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp index a1cdf5607..ac1bccc14 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp index 36a047ac7..65b67988a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp new file mode 100644 index 000000000..81616d6af --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + true, + 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp index 3930123b2..9fc0a6c62 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
* * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp index 60bd6d5c7..dfbcd25be 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp index 549983dc4..8650510c3 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp new file mode 100644 index 000000000..261017c52 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + false, + 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp index 8c32f736f..842c071d9 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp index e4a8919eb..1bf3602e3 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp index d88c4a1e0..302c566e7 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp new file mode 100644 index 000000000..c3f030c5f --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + true, + 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp index 8aeb02787..070e74116 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp index a41d5eace..8011c547d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp index 324e1f0d0..249bf2a54 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp new file mode 100644 index 000000000..9fed2aefc --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + false, + 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp index 630e0f72c..224d5f1bc 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp index b2b7066df..43fea8dee 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp index 9f7544038..dc70813fc 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp new file mode 100644 index 000000000..10ae8c302 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include <ck_tile/core.hpp> +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + true, + 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp index ab6c752ab..4fdbb099c 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp index 988114605..e5d4365a1 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp index 539311424..e028d1bee 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree.
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp new file mode 100644 index 000000000..3c47d406b --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include <ck_tile/core.hpp> +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + false, + 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp index 34dd66471..1651af366 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp index 88305d7de..28fcbfad6 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp index 4ff2f792b..34b227fad 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
* * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp new file mode 100644 index 000000000..ccd459e84 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include <ck_tile/core.hpp> +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + true, + 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp index 9534a7f50..20033dee2 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp index 906dcd51b..c9dece923 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp index 926aadb7f..3b71014f6 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc.
All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp new file mode 100644 index 000000000..09ac8a84e --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include <ck_tile/core.hpp> +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + false, + 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp index 5c29ff3c0..62df2f2dd 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp index 75684001a..07514352b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree.
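Each new `*_maxk_256.cpp` file added in these hunks holds exactly one explicit instantiation definition of `run_grouped_backward_causalmask_bias_dropout_dispatch`, pinning the element type (`ck_tile::fp16_t` or `ck_tile::bf16_t`), the four boolean feature flags (causal mask, attention bias, bias gradient, dropout), and the MaxK head-dimension bound. Keeping one combination per translation unit keeps each object file small and lets the whole set compile in parallel. A minimal, self-contained sketch of the pattern follows; the names are simplified stand-ins, not the real xformers declarations:

```cpp
#include <cstdio>

// Stand-in for the real GroupedBackwardParams (hypothetical field).
struct Params {
  int max_seqlen;
};

// One heavy dispatch template, parameterized like the instances above:
// element type, four feature flags, and the MaxK head-dim bound.
template <typename DataType, bool kHasMask, bool kHasBias, bool kHasBiasGrad,
          bool kHasDropout, int kMaxK>
void run_dispatch(Params& param) {
  // The real code would select and launch a tiled kernel here; this
  // sketch only reports which combination was compiled in.
  std::printf("mask=%d bias=%d biasgrad=%d dropout=%d maxk=%d seqlen=%d\n",
              kHasMask, kHasBias, kHasBiasGrad, kHasDropout, kMaxK,
              param.max_seqlen);
}

// What each generated .cpp contributes: an explicit instantiation
// *definition* that forces this one combination into the object file.
template void run_dispatch<float, true, true, false, true, 256>(Params&);

int main() {
  Params param{1024};
  run_dispatch<float, true, true, false, true, 256>(param);
  return 0;
}
```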
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp index 13e995979..c0d222f05 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp new file mode 100644 index 000000000..8d32e0b35 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include <ck_tile/core.hpp> +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + true, + 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp index d41ee2d19..fe11f7f00 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp index 702a3bf4f..45ba2ddd3 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
* * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp index b450ef78d..e8e20cb4d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp new file mode 100644 index 000000000..81668563e --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include <ck_tile/core.hpp> +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + false, + 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp index be18be183..1961a1a29 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp index b93c05261..ba07be603 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_instances_ref.h b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_instances_ref.h new file mode 100644 index 000000000..367ca6bcf --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_instances_ref.h @@ -0,0 +1,396 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include <ck_tile/core.hpp> +#include "ck_tiled_fmha_grouped_backward.h" + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + true, + 32>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + true, + 32>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + false, + 32>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + false, + 32>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + true, + 32>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + true, + 32>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + false, + 32>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + false, + 32>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + true, + 32>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + true, + 32>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + false, + 32>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + false, + 32>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + true, + 64>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + true, + 64>(GroupedBackwardParams& param, hipStream_t stream); + +extern
template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + false, + 64>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + false, + 64>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + true, + 64>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + true, + 64>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + false, + 64>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + false, + 64>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + true, + 64>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + true, + 64>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + false, + 64>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + false, + 64>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + true, + 128>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + true, + 128>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + false, + 128>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + false, + 128>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + true, + 128>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + true, + 128>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + false, + 128>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + false, + 128>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + true, + 
128>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + true, + 128>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + false, + 128>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + false, + 128>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + true, + 256>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + true, + 256>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + false, + 256>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + false, + 256>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + true, + 256>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + true, + 256>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + false, + 256>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + false, + 256>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + true, + 256>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + true, + 256>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + false, + 256>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + false, + 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp index fc26a3025..15e2f31d8 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro 
Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp new file mode 100644 index 000000000..00effd83c --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include <ck_tile/core.hpp> +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + true, + 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp index 841cc31e5..de4030074 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp index f2865241c..756c1dc18 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree.
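The `fmha_grouped_backward_fp16_instances_ref.h` header added a few hunks above is the other half of the scheme: it repeats every instantiation as an `extern template` declaration. Any translation unit that includes it is promised that the definitions already exist in the generated `.cpp` files, so the compiler suppresses implicit instantiation at each call site and the linker binds to the single prebuilt symbol. A compressed sketch of both sides in one file, with illustrative names rather than the real header:

```cpp
#include <iostream>

struct Params {  // stand-in for GroupedBackwardParams
  int num_batches;
};

template <typename T, bool kHasDropout, int kMaxK>
void run_dispatch(Params& param) {
  std::cout << "dropout=" << kHasDropout << " maxk=" << kMaxK
            << " batches=" << param.num_batches << "\n";
}

// Header side (the *_instances_ref.h role): an explicit-instantiation
// *declaration* -- "do not instantiate here, the definition is elsewhere".
extern template void run_dispatch<float, true, 128>(Params&);

// Instance-file side: the matching explicit-instantiation *definition*,
// which normally lives alone in its own generated .cpp.
template void run_dispatch<float, true, 128>(Params&);

int main() {
  Params param{4};
  run_dispatch<float, true, 128>(param);
  return 0;
}
```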
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp index 35edebe38..7c5978f3f 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp new file mode 100644 index 000000000..1dd5dfa0f --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include <ck_tile/core.hpp> +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + false, + 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp index 8e0d32d5a..69ebd5833 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp index 573ec892b..3218e1606 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
* * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp index 33f9cace9..831e8b9ac 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp new file mode 100644 index 000000000..d7aeb937f --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include <ck_tile/core.hpp> +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + true, + 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp index 683918a99..2659f809d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp index e0c419d2f..466834030 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc.
All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp index 52e41c45d..dc7f41755 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp new file mode 100644 index 000000000..8d1366511 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include <ck_tile/core.hpp> +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + false, + 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp index acdf13265..07e60021b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree.
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp index 6729d5917..d562c0384 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp index 072115903..3b38e48f6 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp new file mode 100644 index 000000000..cc9c0e377 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include <ck_tile/core.hpp> +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + true, + 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp index 64ff3db39..7237f3cab 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
* * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp index f3acd7e17..7f7b87b46 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp index d78c56731..fca2defab 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp new file mode 100644 index 000000000..247d2933f --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include <ck_tile/core.hpp> +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_causalmask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + false, + 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp index 06dc769b9..952d91a05 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp index 63928f3a2..df612447f 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp index 55e21c75a..436b35249 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp index 7c1c89f54..673ace243 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp index 9453c7d2c..12f2dce03 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp index 888c865cd..b05db1117 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp index 1e1231370..ac8a014bc 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp index 03625b779..2bb41cd3b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp index b99a04d7a..8c17a20b7 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp index 12c1b6a90..58357d0f8 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp index 42a6cea30..6b03e2ffd 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp index 81d679689..b98a212b3 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp index e614abdaa..ba57b065d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp index 339f99255..6b5463311 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp index 64b61826f..c1b145ccd 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp index 4983a4ac1..ea2ee5082 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp index fa7649dea..2b9b0559f 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp
index 3a24474ba..6bad209f7 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp
@@ -1,6 +1,6 @@
 /*
-  Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+  Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_instances_ref.h b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_instances_ref.h
new file mode 100644
index 000000000..4b1740f1a
--- /dev/null
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_instances_ref.h
@@ -0,0 +1,236 @@
+
+/*
+  Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ *
+ * The file is automatically generated, don't modify!
+ */
+
+#include
+#include "ck_tiled_fmha_grouped_forward.h"
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    true,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    true,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    false,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    false,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    true,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    true,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    false,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    false,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    true,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    true,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    false,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    false,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    true,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    true,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    false,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    false,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    true,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    true,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    false,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    false,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    true,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    true,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    false,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    false,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    true,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    true,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    false,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    false,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    true,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    true,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    false,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    false,
+    256>(GroupedForwardParams& param, hipStream_t stream);
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp
index 57e895ae9..222d1ed50 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp
@@ -1,6 +1,6 @@
 /*
-  Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+  Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp
index b975fa34c..bcad83e85 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp
@@ -1,6 +1,6 @@
 /*
-  Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+  Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp
index 3be314a73..249011ee1 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp
@@ -1,6 +1,6 @@
 /*
-  Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+  Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp
index 733debc01..15ac9062f 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp
@@ -1,6 +1,6 @@
 /*
-  Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+  Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
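The `fmha_grouped_forward_bf16_instances_ref.h` file added above declares one `extern template` instantiation of `run_grouped_forward_causalmask_bias_dropout_dispatch` for every combination of causal mask, attention bias, and dropout switch at each `maxk` value (32/64/128/256), i.e. 32 declarations per dtype. Below is a minimal sketch of how a generator script could emit such a header. The `emit_ref_header` helper is a hypothetical name; only the template name, the boolean switches, the `maxk` values, and their iteration order are taken from the header itself.

```python
import itertools

# Sketch of an instance-reference generator (emit_ref_header is a
# hypothetical helper; the template name, switches, maxk values and
# their order come from fmha_grouped_forward_bf16_instances_ref.h).
DECL_TEMPLATE = """extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
    {dtype},
    {causal},
    {bias},
    {dropout},
    {maxk}>(GroupedForwardParams& param, hipStream_t stream);
"""


def emit_ref_header(dtype: str) -> str:
    """Emit one extern-template declaration per switch combination."""
    decls = []
    for maxk in (32, 64, 128, 256):
        # (bias, dropout) in the order (T,T), (T,F), (F,T), (F,F),
        # with causal toggled true-then-false innermost, matching the
        # declaration order in the generated header.
        for bias, dropout in itertools.product((True, False), repeat=2):
            for causal in (True, False):
                decls.append(DECL_TEMPLATE.format(
                    dtype=dtype,
                    causal=str(causal).lower(),
                    bias=str(bias).lower(),
                    dropout=str(dropout).lower(),
                    maxk=maxk,
                ))
    return "\n".join(decls)


header = emit_ref_header("ck_tile::bf16_t")
# 4 maxk values x 2^3 boolean combinations = 32 declarations per dtype
assert header.count("extern template") == 32
```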
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp index b762d178c..4b833c8f8 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp index 7d8648a26..3e07c1050 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp index 28a21d93f..276962324 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp index 2fe0721c6..f43d7b41c 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp index 159489e9d..1da0732d8 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp index 507aabe2d..4891094bc 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp index db7d8ed17..d20de70d8 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp index c95898882..2e552a997 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp index 4c5395bed..85f9097f5 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp index 487acd8fa..456ae223a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp index 913d55757..51cbbf71d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp index 137da7aaf..0614b84a2 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp index 68a75552a..6db568b7c 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp index 0603f0d1c..7c14a9f97 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp index 2ba93fcc1..3ad15a89c 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp index 4f95470a5..a0431622e 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp index c12483acf..3c5f652c7 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp index d2bb3b0f2..562298f72 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp index 76752b2e6..9daf7f6c6 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp index 2658965bc..1f3b70c84 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp index 3715f9e40..1ce708426 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp index df210e2b1..f765d967b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp index 0acee7775..65a976a9a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp index 91e6d0778..30b56e1b1 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp index 4c2b6ca25..22ece8289 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp index 5a2df731e..d5a7778e5 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp index 2492c47ea..bc5553560 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp index 7cd86ff79..4b74c49ef 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_instances_ref.h b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_instances_ref.h
new file mode 100644
index 000000000..2ac28a520
--- /dev/null
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_instances_ref.h
@@ -0,0 +1,236 @@
+
+/*
+  Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ *
+ * The file is automatically generated, don't modify!
+ */
+
+#include
+#include "ck_tiled_fmha_grouped_forward.h"
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    true,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    true,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    false,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    false,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    true,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    true,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    false,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    false,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    true,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    true,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    false,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    false,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    true,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    true,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    false,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    false,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    true,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    true,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    false,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    false,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    true,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    true,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    false,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    false,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    true,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    true,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    false,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    false,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    true,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    true,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    false,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_forward_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    false,
+    256>(GroupedForwardParams& param, hipStream_t stream);
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp
index 892446459..b0918f683 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp
@@ -1,6 +1,6 @@
 /*
-  Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+  Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp
index e6914af9d..432cdd978 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp
@@ -1,6 +1,6 @@
 /*
-  Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+  Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp
index 3acb390fe..b7f09b7c3 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp
@@ -1,6 +1,6 @@
 /*
-  Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+  Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp
index b395d5671..8c6ad2498 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp
@@ -1,6 +1,6 @@
 /*
-  Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+  Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp
index a65035381..2b747e5e2 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp
@@ -1,6 +1,6 @@
 /*
-  Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+  Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
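Each instance `.cpp` touched by this diff encodes the same dispatch switches in its file name, e.g. `fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp`; backward instances carry an extra `biasgrad` switch. A hypothetical helper that recovers the template arguments from such a name is sketched below. The parser itself is illustrative and not part of the generator; the naming scheme is read off the file list in this diff.

```python
import re

# Hypothetical parser for the instance file naming scheme visible in
# this diff; illustrative only, not part of the generator itself.
NAME_RE = re.compile(
    r"fmha_grouped_(?P<mode>infer|forward|backward)_"
    r"(?P<dtype>fp16|bf16)_"
    r"(?P<causal>has|no)_causalmask_"
    r"(?P<bias>has|no)_bias_"
    r"(?:(?P<biasgrad>has|no)_biasgrad_)?"  # backward instances only
    r"(?P<dropout>has|no)_dropout_"
    r"maxk_(?P<maxk>\d+)\.cpp"
)


def parse_instance(name: str) -> dict:
    """Map an instance file name back to its dispatch switches."""
    m = NAME_RE.fullmatch(name)
    if m is None:
        raise ValueError(f"not an instance file name: {name}")
    g = m.groupdict()
    return {
        "mode": g["mode"],
        "dtype": g["dtype"],
        "causalmask": g["causal"] == "has",
        "bias": g["bias"] == "has",
        "biasgrad": None if g["biasgrad"] is None else g["biasgrad"] == "has",
        "dropout": g["dropout"] == "has",
        "maxk": int(g["maxk"]),
    }


print(parse_instance(
    "fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp"
))
# -> {'mode': 'forward', 'dtype': 'fp16', 'causalmask': False,
#     'bias': True, 'biasgrad': None, 'dropout': False, 'maxk': 128}
```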
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp index 547fef8b1..0d7c558cd 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp index 8ec916502..3efca3798 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp index 1f3195d6e..dae892ab7 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp index 1498a7d09..d2020485e 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp index 858d55e00..a29929b80 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp index 72b4db4f8..d5f3cdffe 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp index 237cbc71c..6a7482d69 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp index a40d4a3a3..fc5604b5e 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp
index 9fb5462a0..f8741ae4f 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp
index 832ee6f82..8c4e8581b 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp
index beaaaf75a..b29ac4d4f 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp
index 23927f896..52e1d5d71 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp
index 7e0495247..055b769f9 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp
index 59224bc65..9ce3756a6 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp
index 2917ab5d0..46d4e69b7 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp
index ea651303e..5f11a042f 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp
index f1b6c2762..3134e1c4c 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp
index 631b007f7..f858eccb5 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp
index 6bf62e163..5da3272f0 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp
index e9d80dcba..ed632d7ea 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp
index 629111cc2..d336cc52d 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp
index 03a582a51..7095195dd 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp
index 8866842c5..312a64a29 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp
index 0fc722d97..5747867dc 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp
index d7654bcdb..f54dadca5 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp
index aa8b341c5..a6b637a29 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp
index 14d6da36b..47abe27d9 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_instances_ref.h b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_instances_ref.h
new file mode 100644
index 000000000..aa5c84146
--- /dev/null
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_instances_ref.h
@@ -0,0 +1,236 @@
+
+/*
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ *
+ * The file is automatically generated, don't modify!
+ */
+
+#include <ck_tile/core.hpp>
+#include "ck_tiled_fmha_grouped_infer.h"
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    true,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    true,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    false,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    false,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    true,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    true,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    false,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    false,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    true,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    true,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    false,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    false,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    true,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    true,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    false,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    false,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    true,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    true,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    false,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    false,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    true,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    true,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    false,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    false,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    true,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    true,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    true,
+    false,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    true,
+    false,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    true,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    true,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    true,
+    false,
+    false,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::bf16_t,
+    false,
+    false,
+    false,
+    256>(GroupedForwardParams& param, hipStream_t stream);
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp
index 2f4a65c57..95eb7e0ed 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
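The fmha_grouped_infer_bf16_instances_ref.h header just added (like its fp16 counterpart later in this patch) relies on C++ explicit-instantiation declarations: each `extern template` line promises that the corresponding instantiation of run_grouped_infer_causalmask_bias_dropout_dispatch is compiled exactly once, in one of the generated per-configuration instance .cpp files, so translation units that include the header link against that single copy instead of re-instantiating the dispatch for all 32 configurations. Below is a minimal, self-contained sketch of the pattern; run_dispatch and Params are simplified stand-in names for illustration, not the actual xformers symbols.

// extern_template_sketch.cpp -- illustrative only: run_dispatch and Params
// are simplified stand-ins, not the generated xformers symbols.
#include <cstdio>

struct Params {
  int num_batches; // stand-in for the fields of GroupedForwardParams
};

// One dispatch entry point, parameterized over the compile-time
// configuration (reduced here to <kHasCausalMask, kMaxK> for brevity).
template <bool kHasCausalMask, int kMaxK>
void run_dispatch(Params& param) {
  std::printf("causal=%d maxk=%d batches=%d\n",
              int(kHasCausalMask), kMaxK, param.num_batches);
}

// What the generated *_instances_ref.h headers contain: explicit
// instantiation *declarations*, so including TUs do not instantiate.
extern template void run_dispatch<true, 64>(Params&);
extern template void run_dispatch<false, 64>(Params&);

// What each generated instance .cpp contains: the matching explicit
// instantiation *definition*, which emits the code exactly once.
template void run_dispatch<true, 64>(Params&);
template void run_dispatch<false, 64>(Params&);

int main() {
  Params p{4};
  run_dispatch<true, 64>(p);  // links against the explicit instantiation
  run_dispatch<false, 64>(p);
  return 0;
}

Splitting declaration and definition this way lets the generator spread the instantiations across many small .cpp files that compile in parallel, while keeping every consumer of the header cheap to build.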
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp
index f7f7bde51..e9c361bd0 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp
index 3833d791c..5530bb928 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp
index b2c7d4be1..0a5592615 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp
index ab22cec47..5949924e4 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp
index 198837822..4ed017906 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp
index 45d86f18a..d5df90946 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp
index be4cceb0c..8be8afd5e 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp
index af14ace8f..441603639 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp
index 00fbb2563..39e2f9fed 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp
index e7c4b053e..6172df88a 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp
index c9d263f8f..41681f180 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp
index da5ce48b5..98625d142 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp
index 4cac3c509..9d3d73288 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp
index eacbac287..bb537cfe2 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp
index e33f52717..66769f244 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp
index c604204d2..4c35127f9 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp
index f4623e664..12a2a6105 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp
index cb44bd3e6..885584ef4 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp
index 0f0e5290d..a11af5773 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp
index 9b486ea34..8d1f0fb7f 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp
index 2154e1485..50577f7f9 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp
index 4d526353a..07fcfd2eb 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp
index bc14f586d..dc3690344 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp
index 98567089a..b3727732a 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp
index 26211bc69..b8cb89622 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp
index 72722bcf8..a4c2cacf1 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp
index c706a640c..2b36d6f33 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp
index 58107a965..f3827c240 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp
index 2b2c794f5..6627919bb 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp
index e8e3110f9..793fc5c90 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp
index c50ad6f4e..2d50423e7 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_instances_ref.h b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_instances_ref.h
new file mode 100644
index 000000000..f3a5d8501
--- /dev/null
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_instances_ref.h
@@ -0,0 +1,236 @@
+
+/*
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ *
+ * The file is automatically generated, don't modify!
+ */
+
+#include <ck_tile/core.hpp>
+#include "ck_tiled_fmha_grouped_infer.h"
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    true,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    true,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    false,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    false,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    true,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    true,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    false,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    false,
+    32>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    true,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    true,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    false,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    false,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    true,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    true,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    false,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    false,
+    64>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    true,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    true,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    false,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    false,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    true,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    true,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    false,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    false,
+    128>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    true,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    true,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    true,
+    false,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    true,
+    false,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    true,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    true,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    true,
+    false,
+    false,
+    256>(GroupedForwardParams& param, hipStream_t stream);
+
+extern template void run_grouped_infer_causalmask_bias_dropout_dispatch<
+    ck_tile::fp16_t,
+    false,
+    false,
+    false,
+    256>(GroupedForwardParams& param, hipStream_t stream);
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp
index 60e20d744..ffb1b36d6 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp
index e4eeebfcb..db5416d92 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp
index 4b54aa562..d5cce31a7 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp
index 66e02cd50..bb3ad0e57 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp
index 1c42f4206..2f6366584 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp index 46b4bd288..aed425ba5 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp index 2ec8996f4..c3678b42f 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp index 5e2a114a7..7481a9b9a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp index 88ad1f8dd..f6282217d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp index c536e0970..0564af6ec 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp index 0c927196b..afbe9a21f 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp index e84f94f35..99e9133dc 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp index 94db8d5d9..637d40bc1 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp @@ -1,6 +1,6 @@ /* - Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp
index 61abbbf36..ca8cb1bed 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp
index 2a7b8f256..61f1540ae 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp
index d5b1bd180..cad791039 100644
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp
+++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp
@@ -1,6 +1,6 @@
 /*
- Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
diff --git a/xformers/ops/fmha/ck.py b/xformers/ops/fmha/ck.py
index 8a16cfc3b..444326f93 100644
--- a/xformers/ops/fmha/ck.py
+++ b/xformers/ops/fmha/ck.py
@@ -11,7 +11,7 @@
 
 import torch
 
-from ..common import get_xformers_operator, register_operator
+from ..common import get_operator, register_operator
 from . import attn_bias
 from .attn_bias import (
     AttentionBias,
@@ -142,7 +142,7 @@ def _custom_mask_type(bias: Optional[Union[torch.Tensor, AttentionBias]]) -> int
 class FwOp(AttentionFwOpBase):
     """xFormers' MHA kernel based on Composable Kernel."""
 
-    OPERATOR = get_xformers_operator("efficient_attention_forward_ck")
+    OPERATOR = get_operator("xformers", "efficient_attention_forward_ck")
     SUPPORTED_DEVICES: Set[str] = {"cuda"}
     SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}
     SUPPORTED_MAX_K = 256
@@ -177,7 +177,7 @@ class FwOp(AttentionFwOpBase):
     }
     ERROR_RTOL: Mapping[torch.dtype, float] = {
         torch.float: 2e-5,
-        torch.half: 4e-4,
+        torch.half: 3e-3,
         torch.bfloat16: 2e-2,
     }
 
@@ -345,10 +345,10 @@ def operator_flop(
 class BwOp(AttentionBwOpBase):
     __doc__ = FwOp.__doc__
 
-    OPERATOR = get_xformers_operator("efficient_attention_backward_ck")
+    OPERATOR = get_operator("xformers", "efficient_attention_backward_ck")
     SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
     SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
-    SUPPORTED_MAX_K = 128
+    SUPPORTED_MAX_K = 256
     SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (
         type(None),
         torch.Tensor,
@@ -366,12 +366,14 @@ class BwOp(AttentionBwOpBase):
     SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
     SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
     SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
+    SUPPORTS_UNPADDED_LSE = True
     NAME = "ckB"
 
     _TEST_K: List[int] = [
         32,  # 64x64 kernel
         64,
         128,  # 64x128/128x128 kernel
+        256,
     ]
 
     @classmethod
diff --git a/xformers/ops/fmha/ck_decoder.py b/xformers/ops/fmha/ck_decoder.py
index b75c420fd..a5c820bfc 100644
--- a/xformers/ops/fmha/ck_decoder.py
+++ b/xformers/ops/fmha/ck_decoder.py
@@ -7,7 +7,7 @@
 
 import torch
 
-from ..common import get_xformers_operator, register_operator
+from ..common import get_operator, register_operator
 from .attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask
 from .common import AttentionFwOpBase, Context, Inputs
 
@@ -19,7 +19,7 @@ class FwOp(AttentionFwOpBase):
     Tested to work on MI250x.
     """
 
-    OPERATOR = get_xformers_operator("efficient_attention_forward_decoder_ck")
+    OPERATOR = get_operator("xformers", "efficient_attention_forward_decoder_ck")
     SUPPORTED_DEVICES: Set[str] = {"cuda"}
     SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16, torch.float}
     SUPPORTED_MAX_K: int = 256
diff --git a/xformers/ops/fmha/ck_splitk.py b/xformers/ops/fmha/ck_splitk.py
index 6996da6c2..4c7af0794 100644
--- a/xformers/ops/fmha/ck_splitk.py
+++ b/xformers/ops/fmha/ck_splitk.py
@@ -7,7 +7,7 @@
 
 import torch
 
-from xformers.ops.common import get_xformers_operator, register_operator
+from xformers.ops.common import get_operator, register_operator
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask
 from xformers.ops.fmha.common import (
     AttentionFwOpBase,
@@ -20,7 +20,7 @@
 
 @register_operator
 class FwOp(AttentionFwOpBase):
-    OPERATOR = get_xformers_operator("efficient_attention_forward_decoder_splitk_ck")
+    OPERATOR = get_operator("xformers", "efficient_attention_forward_decoder_splitk_ck")
     SUPPORTED_DEVICES = {"cuda"}
     SUPPORTED_DTYPES = {
         torch.half,
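Note on the `get_operator` change above: the one-argument `get_xformers_operator(name)` helper is replaced throughout by a two-argument `get_operator(library, name)` lookup. A minimal sketch, assuming the helper resolves ops registered under `torch.ops.<library>.<name>`; the exact fallback behaviour of `xformers.ops.common.get_operator` may differ:

```python
# Sketch under stated assumptions: resolve torch.ops.<library>.<name>, and
# return a stub that raises a clear error if the compiled extension is missing.
import torch

def get_operator(library: str, name: str):
    def no_such_operator(*args, **kwargs):
        raise RuntimeError(
            f"No such operator {library}::{name} -- was the C++/HIP extension built?"
        )
    try:
        return getattr(getattr(torch.ops, library), name)
    except (RuntimeError, AttributeError):
        return no_such_operator
```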