Commit

Remove unused test
qianfengz committed Aug 21, 2024
1 parent 3cf5721 · commit 01cc08e
Showing 1 changed file with 0 additions and 32 deletions.
tests/test_mem_eff_attention.py: 0 additions & 32 deletions
@@ -998,38 +998,6 @@ def test_dropout_backward_ck(dt, q_len, kv_len, batch_size, k, p):
)


@cuda_only
@disable_tf32
@disable_on_rocm
@pytest.mark.parametrize("k_len", [32])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("kv_len", [3 * 32])
@pytest.mark.parametrize("q_len", [3 * 32])
def test_memory_efficient_attention_full_block_masked(q_len, kv_len, batch_size, k_len):
device = "cuda"
op_fw = fmha.small_k.FwOp
op_bw = fmha.small_k.BwOp

scale = 3
query = torch.randn((batch_size, q_len, k_len), device=device) * scale
key = torch.randn((batch_size, kv_len, k_len), device=device) * scale
value = torch.randn((batch_size, kv_len, k_len), device=device) * scale

# in this case, most of the blocks in a row get masked
attn_bias = torch.full((3, 32), float("-inf"), device=device)
attn_bias[:2, :4] = 0
attn_bias = attn_bias.flatten()[None, None, :].expand(1, q_len, -1)
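    # After flatten + expand, attn_bias has shape (1, q_len, kv_len) = (1, 96, 96):
    # every query row sees the same pattern, with only key positions 0-3 and
    # 32-35 left unmasked, so each softmax row keeps finite entries.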

    out = xformers.ops.memory_efficient_attention(
        query, key, value, attn_bias, op=(op_fw, op_bw)
    )
    ref = ref_attention_for_test(query, key, value, attn_bias)

    assert_allclose(
        out, ref, atol=op_fw.ERROR_ATOL[query.dtype], rtol=op_fw.ERROR_RTOL[query.dtype]
    )
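
For context, ref_attention_for_test is the suite's plain softmax-attention baseline. A minimal sketch of what such a check computes, assuming the standard scaled-dot-product formulation (a hypothetical standalone helper, not the suite's actual implementation):

```python
import math
import torch

def ref_attention(q, k, v, attn_bias=None):
    # BMK layout: q is (B, Mq, K); k and v are (B, Mkv, K).
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    if attn_bias is not None:
        # Additive bias: -inf entries drop key positions out of the softmax.
        scores = scores + attn_bias
    return scores.softmax(dim=-1) @ v
```

The atol/rtol values in the removed test come from the forward op's per-dtype ERROR_ATOL and ERROR_RTOL tables, so the comparison tolerance scales with the precision of the dtype under test.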


@pytest.mark.parametrize("fmt", ["BMK", "BMHK"])
@parametrize_opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs
def test_lowlevel_api_shapes(opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, fmt):
