From 04de99bb28aa6de8d48fab3cdbbc9e3874c994b8 Mon Sep 17 00:00:00 2001
From: Jeremy Reizenstein <669761+bottler@users.noreply.github.com>
Date: Mon, 11 Sep 2023 22:07:15 +0000
Subject: [PATCH] fixes from fbcode (fairinternal/xformers#781)

__original_commit__ = fairinternal/xformers@26ea807035920ece322176b478219b44b76ffc0e
---
 xformers/benchmarks/LRA/run_with_submitit.py  |  2 +-
 .../benchmark_blocksparse_transformers.py     |  6 +++---
 xformers/ops/fmha/common.py                   |  8 ++++++--
 xformers/ops/fmha/flash.py                    | 13 +++++++++++++
 4 files changed, 23 insertions(+), 6 deletions(-)

diff --git a/xformers/benchmarks/LRA/run_with_submitit.py b/xformers/benchmarks/LRA/run_with_submitit.py
index 13945aac6..6fefdd1f5 100644
--- a/xformers/benchmarks/LRA/run_with_submitit.py
+++ b/xformers/benchmarks/LRA/run_with_submitit.py
@@ -70,7 +70,7 @@ def get_init_file():
     return init_file
 
 
-class Trainer(object):
+class Trainer:
     def __init__(self, args):
         self.args = args
 
diff --git a/xformers/benchmarks/benchmark_blocksparse_transformers.py b/xformers/benchmarks/benchmark_blocksparse_transformers.py
index e50b5a7d6..f9cb72a15 100644
--- a/xformers/benchmarks/benchmark_blocksparse_transformers.py
+++ b/xformers/benchmarks/benchmark_blocksparse_transformers.py
@@ -100,7 +100,7 @@ def get_sparsity(mask):
 
 
 @dataclass
-class Configuration(object):
+class Configuration:
     batch_size: int = 32
     num_heads: int = 12
     seq_length: int = 2048
@@ -126,7 +126,7 @@ def __str__(self):
         return ",".join(desc)
 
 
-class AttentionMask(object):
+class AttentionMask:
     def __init__(self, config=None):
         super().__init__()
         if config is None:
@@ -353,7 +353,7 @@ def __str__(self):
 ##############################################
 
 
-class Experiment(object):
+class Experiment:
     def __init__(self, mode, dtype, do_accuracy_check, profile_sputnik):
         self.mode = mode
         self.dtype = dtype
diff --git a/xformers/ops/fmha/common.py b/xformers/ops/fmha/common.py
index a1772b759..8e7abe307 100644
--- a/xformers/ops/fmha/common.py
+++ b/xformers/ops/fmha/common.py
@@ -57,7 +57,9 @@ def normalize_bmhk(self) -> Tuple[int, ...]:
                 ", [batch, seqlen, num_heads, K], or [batch, seqlen, K]."
             )
         if self.value.dtype == torch.int32:
-            # Quantized K/V case, in which the last dims of Q and K/V are different
+            # Quantized K/V case, in which the last dims of Q and K are different.
+            # NB we currently don't have any implementations for quantized KV with
+            # SUPPORTS_DIFFERENT_VALUE_EMBED.
             output_shape = tuple(self.query.shape)
         else:
             output_shape = (self.query.shape[:-1]) + (self.value.shape[-1],)
@@ -151,9 +153,11 @@ def validate_inputs(self) -> None:
             )
         H = self.query.shape[-2]
         if self.query.ndim == 4:  # BMHK
+            quantized_kv_cache = self.value.dtype == torch.int32
+            key_embed_dim = Kv if quantized_kv_cache else K
             valid_shapes = (
                 self.query.shape == (B, Mq, H, K)
-                and self.key.shape == (B, Mkv, H, K)
+                and self.key.shape == (B, Mkv, H, key_embed_dim)
                 and self.value.shape == (B, Mkv, H, Kv)
             )
         G = self.query.shape[2]
diff --git a/xformers/ops/fmha/flash.py b/xformers/ops/fmha/flash.py
index 7535cd54b..743df2215 100644
--- a/xformers/ops/fmha/flash.py
+++ b/xformers/ops/fmha/flash.py
@@ -395,6 +395,19 @@ class BwOp(AttentionBwOpBase):
 
     MAX_HEADDIM_SM8x = 192
 
+    @classmethod
+    def shape_not_supported_reasons(
+        cls, Mq: int, Mkv: int, K: int, Kv: int
+    ) -> List[str]:
+        reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv)
+
+        # In fbcode in mode/dev-nosan, we get nans from flash v2.1 if there
+        # is a strange embedding dimension.
+        if K not in {8, 16, 32, 64, 128, 256}:
+            reasons.append(f"Embed dim {K} not supported")
+
+        return reasons
+
     @classmethod
     def not_supported_reasons(cls, d: Inputs) -> List[str]:
         reasons = super(BwOp, cls).not_supported_reasons(d)