Improvement in ROCM fmha-backward (#1082)
* Change the branch for composable_kernel_tiled submodule and update to latest

* Remove the use of seqlen_cpu in BwOp of ck.py

* Remove the use of seqlen_cpu in BwOp of ck.py

* Align .clang_format with the main branch and re-format C++ files

* Synchronize to latest ck-tiled commit

* Add checks for IS_CK_TILED to some testing scripts

* Update to test_mem_eff_attention.py and ck.py

* Build xformers with ck-tiled as the default

* ensure ck_decoder does not dispatch

* Add disable_on_rocm to some test scripts

* Update to test_mem_eff_attention.py

* apply isort

* apply black

* fix flake8 suggestions

* add license headers and reapply black

* Tiny update to rocm_ci.yml

* Add conditional compilation for CUDA-dependent code on ROCm

* Update to benchmark scripts

* Rename one script file

* Revert "Add conditional compilation for CUDA-dependent code on ROCm"

This reverts commit 12fb41c.

* Update to scripts

* Change and add readme for tests and benchmarks

* Remove the code for supporting old ck

* Remove old composable_kernel from submodule list

* Remove folder third_party/composable_kernel

* Rename the folder

* Remove unused script file

* apply black

* pacify mypy

* fix clang-format

* reapply black

* fix lints

* make test_splitk_reference run on cpu

* add ck modules to docs

* try fixing nvidia build by re-including sparse24 cpp folder into extension sources

* update cutlass to upstream commit

* update flash-attention to upstream commit

* simplify setup.py

* remove duplicate run_batched_infer_causalmask_attnbias_dispatched<f16, true, true, 128>

* add hip version and pytorch hip arch list to xformers build info

* fix build

* patch around the unhappy path in get_hip_version

* skip test_grad_checkpointing for triton_splitk since it doesn't have bwop

* re-enable test_mqa_forward since ck tiled is the current implementation

* make skip test_wrong_alignment more generic

* reapply black

* simplify test_decoder

* put python version check inside triton_splitk op

* fix logic

* cleanup python3.9 checks in tests

* cleanup test_attentions

* cleanup test_checkpoint as test running on cpu does not depend on gpu platform

* fix lints

* try fixing win build by conditional import of triton in triton op

* re-enable test_triton_layernorm as it passes

* re-enable test_triton_blocksparse as it passes

* cleanup test_sparse_tensors

* cleanup test_custom_ops

* reapply black

* cleanup test_core_attention

* benchmark ck ops on rocm only

* fix mypy

* fix lint: black

* fix lints: mypy

* split-k decoder: move all tunable parameters to the top of cpp file

* apply clang-format

* Rename HDim/headdim to MaxK/maxk

* Move some header files to ck examples for later reuse

* Replace the qs_ks_vs pipeline with the qr_ks_vs pipeline when HeadDim is 256, for better performance

* rm test_ck_7

* dump kernel resource usage to compilation logs similar to nv

* Update the C++ extension to the latest change of the ck_tile/dev fwd kernel (added dropout)

* Update the C++ extension to use the ck_tile/dev fmha bwd kernel

* Update to add dropout for fmha backward

* Update in attention.cpp to align efficient_attention_backward_ck interface parameters

* Enable BwdOp in ck.py

* Support grad_out having different strides from out

* Force seqstart_q/seqstart_k to be in device memory in ck.py

* Remove duplicated code in ck_tiled_fmha_grouped_forward.h/infer.h

* Use optimized async pipeline where 8x headdim length is assumed

* Fix in batched_infer

* Update to track ck_tile/opt_padding_fa_train_xformers branch

* Update rocm_ci.yml

configuring the self-hosted runner

* Update to use the newer FmhaFwdEpilogue

* Update rocm_ci.yml

add option to manually trigger workflow

* Update rocm_ci.yml

remove condition which skips ci unless github event contains string 'rocm'

* copy rocm_ci workflow from main branch

* Update rocm_ci.yml

Bump upload-artifact version

* Update to use the newer FmhaFwdEpilogue for grouped infer/forward

* Temporarily disable the use of the QRKSVSAsync() pipeline

* Update rocm_ci.yml

add a daily run

* Implement the ck_rand_uniform interface for generating a random number tensor

* Add dropout to the infer path (needed by xformers test_dropout)

* Update to support test_dropout and test_dropout_backward tests

* Update the padding method in batched_backward.h

* Update the OGradDotO kernel padding method

* Change the backward padding checking condition

* Add batch_stride_lse/d parameters to adapt grouped mode forward/backward to [num_batches, H, MaxSeqlenQ] layout

* Fill the grad_bias in advance

* Add support for kHasBiasGrad as instance template

* Remove the use of hdim_stride_do in fmha backward

* Force kPadSeqLenQ/kPadSeqLenK to be true in batched-backward to save compile time

* Fix missing passing of {philox_seed, philox_offset} in inference path

* Use SimplifiedGenericAttentionMask to replace GenericAttentionMask

* Shorten the instance file names

* Rename the template parameters

* Simplify the names of the dispatch class and interfaces

* Changes to reuse the kernel files under ck_tile examples/91_tile_program/fmha folder

* Update test_mem_eff_attention.py for test_dropout/test_dropout_backward/test_backward on rocm

* Tiny change to the philox_cuda_state input setting

* Allocate logsumexp to ensure aligned access by each thread-group

* Add checking for query/key headdim size in attention_backward_generic

* Use ck_tile/opt_padding_fa_train_pr2 and synchronize the backward code with the changes

* Enable the async pipeline in the batched inference path for performance

* Re-organize cpp instances for calling fmha infer kernel

* Re-organize cpp instances for calling fmha forward kernel

* Re-organize cpp instances for calling fmha backward kernel

* Point composable_kernel_tiled to the ck_tile/opt_padding_fa_train branch

* Update to synchronize with the latest commits in ck_tile/opt_padding_fa_train

* update submodule to public

* Update to the criteria for padding seqlen_k in batched infer/forward

* Keep tracking the latest ck-tile commits

* Tiny fix to the decoder includes

* Point ck-tiled to the ck_tile/opt_padding branch

* Enable some attn_bias types which were previously disabled by old-ck in ck.py

* Add script generate_instances.py which helps to generate instances

* Simplify logic for seqstart_q/k

ROCm@566d26f has put seqstart_k/q on the device, so the logic here can be simplified.

The upstream xformers does not have this optimization and copies seqstart_q/k on every iteration. We'd like this change to land here and then be merged upstream.
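As a rough illustration of the pattern, here is a minimal sketch (not the actual ck.py code; the helper name and example lengths are made up):

```python
import torch

def make_seqstart(seqlens, device):
    # Cumulative sequence-start offsets [0, s0, s0+s1, ...], created as an
    # int32 tensor directly on the target device.
    seqstart = [0]
    for s in seqlens:
        seqstart.append(seqstart[-1] + s)
    return torch.tensor(seqstart, dtype=torch.int32, device=device)

# Built once on the device; the kernel arguments can then reference these
# tensors directly, with no per-iteration host-to-device copy.
seqstart_q = make_seqstart([3, 5, 2], torch.device("cuda"))
seqstart_k = make_seqstart([4, 4, 4], torch.device("cuda"))
```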

* Add Async pipeline to grouped mode inference path

* Use explicit true for the kPadSeqLenQ/kPadHeadDimQ/kPadHeadDimV templates for the Async pipeline

* Synchronize to the update of composable_kernel_tiled for better performance

* Update rocm_ci.yml - clean up dangling images after ci run

* Avoid unused-const-variable warning

Our compiler errors on the unused-const-variable warning, so just fix this.

* Tiny change in the BlockTile/Shape setting overrides

* try to align the fmha C++ extension with ck_tile in the ck develop branch

* Synchronize composable_kernel_tiled to latest ck develop

* Use FmhaFwdTilePartitioner_HBS only with seqlen_k padded cases

* Tiny fix/change to make test_forward/test_backward/test_dropout/test_dropout_backward_ck pass

* Fix compiling issue with regard to Invoker definitions in forward_decoder/forward_decoder_split operators

* Keep using -Woverloaded-virtual

* Fix clang-format for headers and cpp files

* Fix format in python scripts

* Add noqa: C801 for generate_instances.py

* Align dispatch_bw with main branch

* Align ops/fmha/common.py with main branch

* Synchronize third-party/composable_kernel_tiled to the latest ck_tile commits for better performance

* Relax the atol for test_forward and test_dropout due to the use of packed fp16_2_fp32 conversion in ck_tile

* Generate html report for tests run with rocm_ci.yml

* archive test results when tests have failed

* Always clean up dangling docker images in rocm_ci

* Bump python to 3.11 in rocm_ci.yml

* Disable flash attention tests in rocm_ci.yml

The op is currently broken; TBD: either make the op work or disable it on ROCm.

* Try to fix rocm_ci.yml

Init must be called before activation

* try to fix rocm_ci.yml flow by overriding PATH

* Fix setup.py path in rocm_ci.yml

* cd to xformers dir before running install in rocm_ci.yml

* Use pip to install xformers in rocm_ci.yml

* Possibly fix python version resolution in rocm_ci.yml

* Set the correct path for pytest in rocm_ci.yml

* remove test_reference_splitk as it was moved to a different file during the first upstream

remove test_mqa_forward from develop, as the test fails in develop and doesn't run upstream

remove reference attention splitk from the test file; it exists in test_splitk_reference

sync test_mem_eff_attention with upstream

* make sure ck operators have a name to be visible in the dispatcher

* fix sm version checks to happen only on CUDA, not ROCm

* (2/n) fix sm version checks to happen only on CUDA, not ROCm

* Remove _check_large_shapes checking in fmha/ck.py (#1067)

* make xformers install editable to fix cpp extensions detection

* Update to use the improved fmha-bwd (compilation passes)

* Update to get 80% of the test_backward and test_dropout_backward_ck cases passing

* Replace the use of ConvertGradQ with torch tensor type conversion

* Change the tile settings for MaxK=32

* Fix padding setting bug in grouped_backward

* Change -DCK_FMHA_FWD_FAST_EXP2=1 to -DCK_TILE_FMHA_FWD_FAST_EXP2=1

* Point the composable_kernel_tiled submodule to ck_tile/fa_bwd_opt branch

* Disable flshattF and flshattB on ROCM

* Add the -mllvm -enable-post-misched=0 compile options for ROCm in setup.py

* Disable flshattF and flshattB on ROCM

* Update to support separate grad_q_f32_strides due to the API change in the fmha_bwd_kernel

* Use old method for setting BlockDropout due to the revert in fmha_fwd_kernel

* Tiny fix in grouped_backward

* Use packed tensor allocation for grad_q_f32

* Update the ConvertGradQ kernel call

* Tiny update

* Fix the parameter location in grouped_backward

* Adjust headdim128 tile shapes for better performance

* Update backward kernel calling due to adding of nhead_stride_dk/nhead_stride_dv parameters

* Synchronize with CK to use separate pipelines for the kPadHeadDim true and false situations

* Use convertDQ kernel

* Update to use unpadded lse layout

* Add explicit headdim256 instances for fmha backward

* Add the missing headdim256 instance references

* Update generate.py and re-generate the instance files using it

* Update generate.py to generate instance references and use the generated reference headers

* Relax the RTOL of ckFwOp from 4e-4 to 3e-3 due to one case with large result values

* Change to use .h rather than .hpp as suffix for generated header files

* Fix in .gitignore

* Update the bwd setting to use only the IGLP pipeline

* Synchronize to latest ck_tile fix and align the headdim64 tile shape setting

* Reformat the generated instances cpp files

* Fix to the backward Trait

* Set occupancy to -1 to avoid the compiling warning

* Revert "Set occupancy to -1 to avoid the compiling warning"

This reverts commit fa6d8b3.

* Add environment variable and compiler definition to control the generating of headdim256 instances

* Add the --ignore-hd256 argument to generate_instances.py and some updates to this script

* Add environment variable ENABLE_HIP_FMHA_RTN_BF16_CONVERT to enable the use of RTN bf16 conversion

* Remove commented lines in test_mem_eff_attention.py

* Synchronize to latest ck_tile commit

* apply black

* apply flake8

* fix mypy

* revert disabling the flash operator on rocm

* Synchronize to ck_tile latest commit again

* Re-position the composable_kernel submodule to the develop branch

* Avoid the Async pipeline when kHasBias is true

* clang-format for two files

* Change allocation of grouped mode lse from [H, M] to [1, H, M] to match the xformers scripts

* Synchronize to the upstream rocm_ci workflows

* Re-format tests/test_mem_eff_attention.py

* Update generate_instances.py so that the script can be called from a flexible location

* Add GENERATE_INSTANCES.md

* clean up commented code

* Remove unused test

---------

Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: Xiaodong Wang <xw285@cornell.edu>
4 people authored Aug 22, 2024
1 parent 7f42efb commit e3900ba
Showing 494 changed files with 5,525 additions and 1,501 deletions.
5 changes: 3 additions & 2 deletions .gitignore
@@ -67,6 +67,7 @@ xformers/csrc/attention/hip_fmha/*.hip
xformers/csrc/attention/hip_fmha/*_hip.h
xformers/csrc/attention/hip_fmha/instances/*.cu
xformers/csrc/attention/hip_fmha/instances/*.hip
xformers/csrc/attention/hip_fmha/instances_tiled/*.cu
xformers/csrc/attention/hip_fmha/instances_tiled/*.hip
xformers/csrc/attention/hip_fmha/instances/*.cu
xformers/csrc/attention/hip_fmha/instances/*.hip
xformers/csrc/attention/hip_fmha/instances/*_hip.h

19 changes: 18 additions & 1 deletion setup.py
@@ -418,6 +418,14 @@ def get_extensions():
"--ptxas-options=-allow-expensive-optimizations=true",
]
elif torch.cuda.is_available() and torch.version.hip:
disable_hd256_hip_fmha = os.getenv("DISABLE_HD256_HIP_FMHA", "0")
if disable_hd256_hip_fmha == "1":
source_hip_maxk_256 = []
for ff in source_hip:
if ff.endswith("maxk_256.cpp"):
source_hip_maxk_256 += [ff]
source_hip = list(set(source_hip) - set(source_hip_maxk_256))

rename_cpp_cu(source_hip)
rocm_home = os.getenv("ROCM_PATH")
hip_version = get_hip_version(rocm_home)
@@ -436,9 +444,16 @@ def get_extensions():
Path(this_dir) / "third_party" / "composable_kernel_tiled" / "include"
]

use_rtn_bf16_convert = os.getenv("ENABLE_HIP_FMHA_RTN_BF16_CONVERT", "0")

generator_flag = []
if disable_hd256_hip_fmha == "1":
generator_flag += ["-DFMHA_SUPPORT_MAX_HEADDIM_128=1"]

cc_flag = ["-DBUILD_PYTHON_PACKAGE"]
if use_rtn_bf16_convert == "1":
cc_flag += ["-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=0"]

extra_compile_args = {
"cxx": ["-O3", "-std=c++17"] + generator_flag,
"nvcc": [
f"--offload-arch={os.getenv('HIP_ARCHITECTURES', 'native')}",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-DCK_FMHA_FWD_FAST_EXP2=1",
"-DCK_TILE_FMHA_FWD_FAST_EXP2=1",
"-fgpu-flush-denormals-to-zero",
"-Werror",
"-Woverloaded-virtual",
"-mllvm",
"-enable-post-misched=0",
]
+ generator_flag
+ cc_flag,
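For readability, the effect of the two new build-time environment variables introduced above can be condensed into a small standalone sketch (illustrative only; the file names are placeholders, and the real logic lives inside `get_extensions()` in setup.py):

```python
import os

source_hip = ["fmha_infer_maxk_64.cpp", "fmha_infer_maxk_256.cpp"]  # placeholder names
generator_flag = []
cc_flag = ["-DBUILD_PYTHON_PACKAGE"]

# DISABLE_HD256_HIP_FMHA=1 drops the headdim-256 instance sources and tells
# the C++ side that only head dimensions up to 128 are supported.
if os.getenv("DISABLE_HD256_HIP_FMHA", "0") == "1":
    source_hip = [f for f in source_hip if not f.endswith("maxk_256.cpp")]
    generator_flag += ["-DFMHA_SUPPORT_MAX_HEADDIM_128=1"]

# ENABLE_HIP_FMHA_RTN_BF16_CONVERT=1 switches ck_tile to round-to-nearest
# float-to-bf16 conversion.
if os.getenv("ENABLE_HIP_FMHA_RTN_BF16_CONVERT", "0") == "1":
    cc_flag += ["-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=0"]
```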
19 changes: 7 additions & 12 deletions tests/test_mem_eff_attention.py
@@ -38,13 +38,16 @@
if torch.cuda.is_available():
compute_capability = torch.cuda.get_device_capability("cuda")
sm70_or_better_only = pytest.mark.skipif(
compute_capability < (7, 0), reason="requires sm70+"
torch.version.cuda is not None and compute_capability < (7, 0),
reason="requires sm70+",
)
sm75_or_better_only = pytest.mark.skipif(
compute_capability < (7, 5), reason="requires sm75+"
torch.version.cuda is not None and compute_capability < (7, 5),
reason="requires sm75+",
)
sm80_or_better_only = pytest.mark.skipif(
compute_capability < (8, 0), reason="requires sm80+"
torch.version.cuda is not None and compute_capability < (8, 0),
reason="requires sm80+",
)
sm90_or_better_only = pytest.mark.skipif(
compute_capability < (9, 0), reason="requires sm90+"
@@ -670,16 +673,8 @@ def test_backward(

if op_bw == fmha.ck.BwOp:
op_fw = fmha.ck.FwOp
if dtype == torch.bfloat16:
pytest.skip(
"CK Fmha backward for bfloat16 currently is not very accurate for some cases!"
)
if grad_out_contiguous is False:
pytest.skip("CK Fmha does not support contiguous layout for grad_out!")
if k % 2 != 0:
pytest.skip(
"CK Fmha currently requires the headdim size of query input be an even value!"
)

qkv = None

@@ -1586,7 +1581,7 @@ def test_decoder(
# kv_heads = 1: multiquery
# kv_heads = None: neither MQA nor GQA
# kv_heads > 1: BMGHK
if dtype == "bf16" and compute_capability < (8, 0):
if dtype == "bf16" and torch.version.cuda and compute_capability < (8, 0):
raise pytest.skip("BF16 is only supported on SM80+")
import triton

2 changes: 1 addition & 1 deletion third_party/composable_kernel_tiled
33 changes: 33 additions & 0 deletions xformers/csrc/attention/hip_fmha/GENERATE_INSTANCES.md
@@ -0,0 +1,33 @@

# Instances generator

The instances generator is a simple Python tool used to generate several hundred instances (.cpp files) and their references (.h files).
Without this tool, manually writing those instances and references would be very laborious and error-prone.

The instances generated by this script are divided into three categories, visible in the script:
* Infer -- which refers to instances for calling inference-only kernels
* Forward -- which refers to instances for calling training forward kernels
* Backward -- which refers to instances for calling training backward kernels

The instance generator is intended for the HIP fmha developers themselves. It is not meant to be used by xformers users for
building xformers, since for xformers users the instances are already provided as part of the xformers code.

## How to use the instance generator

* To generate the complete set of instances supported by the current implementation:

```
#> python xformers/csrc/attention/hip_fmha/generate_instances.py
```
* To generate a reduced set of instances (when headdim256 is not required):

```
#> python xformers/csrc/attention/hip_fmha/generate_instances.py --ignore-hd256
```
* More options besides `--ignore-hd256` could be added to support further customization of instance generation as required

## Where the instance files are located
The instance files and reference files are always located in an `instances/` folder under the same directory
as `generate_instances.py` itself.

@@ -122,10 +122,6 @@ efficient_attention_backward_ck(
int64_t K = query.size(3);
int64_t Kv = value.size(3);

if (K % 2 != 0)
throw std::runtime_error(
"Currently CK Fmha requires the headdim of query/key be an even value!");

auto opts = query.options();

at::Tensor grad_q, grad_k, grad_v, grad_bias;
@@ -143,7 +139,6 @@ efficient_attention_backward_ck(
grad_q = chunk.select(2, 0);
grad_k = chunk.select(2, 1);
grad_v = chunk.select(2, 2);
grad_q.fill_(0);
} else if (
key.size(3) == value.size(3) &&
key.storage().is_alias_of(value.storage())) {
@@ -157,14 +152,24 @@ efficient_attention_backward_ck(
grad_v = chunk.select(2, 1);

grad_q = at::empty_strided(query.sizes(), query.strides(), query.options());
grad_q.fill_(0);
} else {
grad_q = at::empty_strided(query.sizes(), query.strides(), query.options());
grad_k = at::empty_strided(key.sizes(), key.strides(), key.options());
grad_v = at::empty_strided(value.sizes(), value.strides(), value.options());
grad_q.fill_(0);
}

at::Tensor grad_q_f32;
const bool use_grad_q_f32 =
(query.scalar_type() == at::ScalarType::BFloat16 ||
query.scalar_type() == at::ScalarType::Half);

if (use_grad_q_f32) {
grad_q_f32 = at::empty(grad_q.sizes(), opts.dtype(at::kFloat));
grad_q_f32.fill_(0);
} else {
grad_q.fill_(0);
};

// CK-FlashAttn requires q/k/v to have same shapes with dQ/dK/dV respectively
TORCH_CHECK(query.sizes() == grad_q.sizes());
TORCH_CHECK(query.strides() == grad_q.strides());
@@ -211,7 +216,7 @@ efficient_attention_backward_ck(

TORCH_CHECK(p.B == logsumexp.size(0));
TORCH_CHECK(p.Hq == logsumexp.size(1));
TORCH_CHECK(p.M <= logsumexp.size(2));
TORCH_CHECK(p.M == logsumexp.size(2));

if (scale.has_value()) {
p.scale = float(*scale);
@@ -229,6 +234,11 @@ efficient_attention_backward_ck(
p.grad_k_ptr = is_mqa_gqa ? tmp_grad_k.data_ptr() : grad_k.data_ptr();
p.grad_v_ptr = is_mqa_gqa ? tmp_grad_v.data_ptr() : grad_v.data_ptr();

if (use_grad_q_f32)
p.grad_q_f32_ptr = grad_q_f32.data_ptr();
else
p.grad_q_f32_ptr = nullptr;

p.q_strides = {
static_cast<int>(query.stride(0)),
static_cast<int>(query.stride(1)),
@@ -260,6 +270,14 @@ efficient_attention_backward_ck(
static_cast<int>(logsumexp.stride(1)),
static_cast<int>(logsumexp.stride(2))};

if (use_grad_q_f32) {
p.grad_q_f32_strides = {
static_cast<int>(grad_q_f32.stride(0)),
static_cast<int>(grad_q_f32.stride(1)),
static_cast<int>(grad_q_f32.stride(2)),
static_cast<int>(grad_q_f32.stride(3))};
}

if (is_mqa_gqa) {
p.grad_k_strides = {
static_cast<int>(tmp_grad_k.stride(0)),
@@ -335,9 +353,9 @@ efficient_attention_backward_ck(
p.max_seqlen_q = *max_seqlen_q_;
p.max_seqlen_k = *max_seqlen_k_;

TORCH_CHECK(p.num_batches == logsumexp.size(0));
// unpadded lse layout required
TORCH_CHECK(p.Hq == logsumexp.size(1));
TORCH_CHECK(p.max_seqlen_q <= logsumexp.size(2));
TORCH_CHECK(p.M == logsumexp.size(2));

if (scale.has_value())
p.scale = float(*scale);
@@ -366,10 +384,16 @@ efficient_attention_backward_ck(
static_cast<int>(grad_out.stride(3))};

p.lsed_strides = {
static_cast<int>(logsumexp.stride(0)),
static_cast<int>(logsumexp.stride(1)),
static_cast<int>(logsumexp.stride(2))};

if (use_grad_q_f32) {
p.grad_q_f32_strides = {
static_cast<int>(grad_q_f32.stride(1)),
static_cast<int>(grad_q_f32.stride(2)),
static_cast<int>(grad_q_f32.stride(3))};
}

if (is_mqa_gqa) {
p.grad_k_strides = {
static_cast<int>(tmp_grad_k.stride(1)),
@@ -480,6 +504,11 @@ efficient_attention_backward_ck(
p.grad_k_ptr = is_mqa_gqa ? tmp_grad_k.data_ptr() : grad_k.data_ptr();
p.grad_v_ptr = is_mqa_gqa ? tmp_grad_v.data_ptr() : grad_v.data_ptr();
p.grad_bias_ptr = bias_requires_grad ? grad_bias.data_ptr() : nullptr;

if (use_grad_q_f32)
p.grad_q_f32_ptr = grad_q_f32.data_ptr();
else
p.grad_q_f32_ptr = nullptr;
};

auto inDataType = query.scalar_type();
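The dQ-accumulation change above can be summarized with a PyTorch-level sketch (illustrative only; in the extension the accumulation itself happens inside the HIP backward kernel):

```python
import torch

def allocate_grad_q(query: torch.Tensor):
    # Mirrors the allocation logic above: fp16/bf16 inputs get an extra fp32
    # accumulation buffer (grad_q_f32); fp32 inputs accumulate into grad_q directly.
    grad_q = torch.empty_strided(
        query.size(), query.stride(), dtype=query.dtype, device=query.device
    )
    if query.dtype in (torch.float16, torch.bfloat16):
        grad_q_f32 = torch.zeros(grad_q.size(), dtype=torch.float32, device=query.device)
        return grad_q, grad_q_f32  # kernel fills grad_q_f32, which is then converted into grad_q
    grad_q.zero_()
    return grad_q, None
```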
@@ -57,7 +57,7 @@ template <
int32_t ThreadsPerWavefront,
int32_t WavefrontsPerBlock,
int32_t KV_M_MAX = 8192,
int32_t K_MAX = 256>
int32_t K_MAX = K_MAX>
at::Tensor& efficient_attention_forward_decoder_ck_out_impl(
const at::Tensor& XQ, // [B, 1, G, H, D]
const at::Tensor& cache_K, // [B, KV_M_MAX, G, H or 1, D]
@@ -316,18 +316,14 @@ efficient_attention_forward_ck(
p.dropout_prob = 0.0f;

if (p.compute_logsumexp) {
// align the access of logsumexp by each thread-group in cache-line size
int aligned_seqlen_q = (p.max_seqlen_q + 15) / 16 * 16;
logsumexp = at::empty(
{p.num_batches, Hq, aligned_seqlen_q}, opts.dtype(at::kFloat));
logsumexp = at::empty({1, Hq, M}, opts.dtype(at::kFloat));
p.logsumexp_ptr = logsumexp.data_ptr();
p.lse_strides = {
static_cast<int>(logsumexp.stride(0)),
static_cast<int>(logsumexp.stride(1)),
static_cast<int>(logsumexp.stride(2))};
} else {
p.logsumexp_ptr = nullptr;
p.lse_strides = {0, 0, 0};
p.lse_strides = {0, 0};
}
};

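For reference, the unpadded grouped-mode logsumexp layout used above (one row per query token, no per-batch padding) looks roughly like this at the Python level (the sequence lengths and head count are example values):

```python
import torch

seqlens_q = [3, 5, 2]            # example variable sequence lengths
Hq = 8                           # example number of query heads
M = sum(seqlens_q)               # total query tokens across all batches

# Matches the at::empty({1, Hq, M}, kFloat) allocation above.
logsumexp = torch.empty(1, Hq, M, dtype=torch.float32, device="cuda")
lse_strides = tuple(logsumexp.stride())  # what gets passed as p.lse_strides
```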
