diff --git a/xformers/benchmarks/LRA/run_with_submitit.py b/xformers/benchmarks/LRA/run_with_submitit.py
index 13945aac6..6fefdd1f5 100644
--- a/xformers/benchmarks/LRA/run_with_submitit.py
+++ b/xformers/benchmarks/LRA/run_with_submitit.py
@@ -70,7 +70,7 @@ def get_init_file():
     return init_file
 
 
-class Trainer(object):
+class Trainer:
     def __init__(self, args):
         self.args = args
 
diff --git a/xformers/benchmarks/benchmark_blocksparse_transformers.py b/xformers/benchmarks/benchmark_blocksparse_transformers.py
index e50b5a7d6..f9cb72a15 100644
--- a/xformers/benchmarks/benchmark_blocksparse_transformers.py
+++ b/xformers/benchmarks/benchmark_blocksparse_transformers.py
@@ -100,7 +100,7 @@ def get_sparsity(mask):
 
 
 @dataclass
-class Configuration(object):
+class Configuration:
     batch_size: int = 32
     num_heads: int = 12
     seq_length: int = 2048
@@ -126,7 +126,7 @@ def __str__(self):
         return ",".join(desc)
 
 
-class AttentionMask(object):
+class AttentionMask:
     def __init__(self, config=None):
         super().__init__()
         if config is None:
@@ -353,7 +353,7 @@ def __str__(self):
 
 
 ##############################################
-class Experiment(object):
+class Experiment:
     def __init__(self, mode, dtype, do_accuracy_check, profile_sputnik):
         self.mode = mode
         self.dtype = dtype
diff --git a/xformers/ops/fmha/common.py b/xformers/ops/fmha/common.py
index a1772b759..8e7abe307 100644
--- a/xformers/ops/fmha/common.py
+++ b/xformers/ops/fmha/common.py
@@ -57,7 +57,9 @@ def normalize_bmhk(self) -> Tuple[int, ...]:
                 ", [batch, seqlen, num_heads, K], or [batch, seqlen, K]."
             )
         if self.value.dtype == torch.int32:
-            # Quantized K/V case, in which the last dims of Q and K/V are different
+            # Quantized K/V case, in which the last dims of Q and K are different.
+            # NB we currently don't have any implementations for quantized KV with
+            # SUPPORTS_DIFFERENT_VALUE_EMBED.
             output_shape = tuple(self.query.shape)
         else:
             output_shape = (self.query.shape[:-1]) + (self.value.shape[-1],)
@@ -151,9 +153,11 @@ def validate_inputs(self) -> None:
             )
         H = self.query.shape[-2]
         if self.query.ndim == 4:  # BMHK
+            quantized_kv_cache = self.value.dtype == torch.int32
+            key_embed_dim = Kv if quantized_kv_cache else K
             valid_shapes = (
                 self.query.shape == (B, Mq, H, K)
-                and self.key.shape == (B, Mkv, H, K)
+                and self.key.shape == (B, Mkv, H, key_embed_dim)
                 and self.value.shape == (B, Mkv, H, Kv)
             )
         G = self.query.shape[2]
diff --git a/xformers/ops/fmha/flash.py b/xformers/ops/fmha/flash.py
index 7535cd54b..743df2215 100644
--- a/xformers/ops/fmha/flash.py
+++ b/xformers/ops/fmha/flash.py
@@ -395,6 +395,19 @@ class BwOp(AttentionBwOpBase):
 
     MAX_HEADDIM_SM8x = 192
 
+    @classmethod
+    def shape_not_supported_reasons(
+        cls, Mq: int, Mkv: int, K: int, Kv: int
+    ) -> List[str]:
+        reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv)
+
+        # In fbcode in mode/dev-nosan, we get nans from flash v2.1 if there
+        # is a strange embedding dimension.
+        if K not in {8, 16, 32, 64, 128, 256}:
+            reasons.append(f"Embed dim {K} not supported")
+
+        return reasons
+
     @classmethod
     def not_supported_reasons(cls, d: Inputs) -> List[str]:
         reasons = super(BwOp, cls).not_supported_reasons(d)
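For illustration only (not part of the patch): a minimal standalone sketch of the BMHK shape check that the validate_inputs change above relaxes for a quantized KV cache. The helper name bmhk_shapes_valid and the tensor sizes are hypothetical; the real check lives inside Inputs.validate_inputs in xformers/ops/fmha/common.py.

import torch

def bmhk_shapes_valid(query, key, value):
    # Mirrors the change above: when K/V are quantized (int32), the key's
    # last dim is compared against the value embed dim Kv instead of the
    # query embed dim K.
    B, Mq, H, K = query.shape
    Mkv, Kv = key.shape[1], value.shape[-1]
    quantized_kv_cache = value.dtype == torch.int32
    key_embed_dim = Kv if quantized_kv_cache else K
    return (
        query.shape == (B, Mq, H, K)
        and key.shape == (B, Mkv, H, key_embed_dim)
        and value.shape == (B, Mkv, H, Kv)
    )

# Regular case: Q, K, V share the same embedding dim.
q = torch.randn(2, 16, 4, 64)
k = torch.randn(2, 16, 4, 64)
v = torch.randn(2, 16, 4, 64)
print(bmhk_shapes_valid(q, k, v))  # True

# Quantized KV cache: int32 K/V whose last dim differs from the query's;
# accepted because the key is checked against Kv, not K.
kq = torch.zeros(2, 16, 4, 72, dtype=torch.int32)
vq = torch.zeros(2, 16, 4, 72, dtype=torch.int32)
print(bmhk_shapes_valid(q, kq, vq))  # True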