fixes from fbcode (fairinternal/xformers#781)
__original_commit__ = fairinternal/xformers@26ea807
bottler authored and xFormers Bot committed Sep 11, 2023
1 parent bcfddde commit 04de99b
Showing 4 changed files with 23 additions and 6 deletions.
2 changes: 1 addition & 1 deletion xformers/benchmarks/LRA/run_with_submitit.py
@@ -70,7 +70,7 @@ def get_init_file():
     return init_file


-class Trainer(object):
+class Trainer:
     def __init__(self, args):
         self.args = args

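
Note: the class Trainer(object) -> class Trainer cleanups in this commit (here and in the benchmark file below) are behavior-preserving, since every Python 3 class inherits from object implicitly. A minimal sketch, separate from the repository code, demonstrating the equivalence:

    # Both spellings produce the same new-style class in Python 3.
    class WithBase(object):
        pass

    class WithoutBase:
        pass

    # The method resolution order ends in object either way.
    assert WithBase.__mro__[-1] is object
    assert WithoutBase.__mro__[-1] is object
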
6 changes: 3 additions & 3 deletions xformers/benchmarks/benchmark_blocksparse_transformers.py
@@ -100,7 +100,7 @@ def get_sparsity(mask):


 @dataclass
-class Configuration(object):
+class Configuration:
     batch_size: int = 32
     num_heads: int = 12
     seq_length: int = 2048
@@ -126,7 +126,7 @@ def __str__(self):
         return ",".join(desc)


-class AttentionMask(object):
+class AttentionMask:
     def __init__(self, config=None):
         super().__init__()
         if config is None:
@@ -353,7 +353,7 @@ def __str__(self):
 ##############################################


-class Experiment(object):
+class Experiment:
     def __init__(self, mode, dtype, do_accuracy_check, profile_sputnik):
         self.mode = mode
         self.dtype = dtype
8 changes: 6 additions & 2 deletions xformers/ops/fmha/common.py
@@ -57,7 +57,9 @@ def normalize_bmhk(self) -> Tuple[int, ...]:
                 ", [batch, seqlen, num_heads, K], or [batch, seqlen, K]."
             )
         if self.value.dtype == torch.int32:
-            # Quantized K/V case, in which the last dims of Q and K/V are different
+            # Quantized K/V case, in which the last dims of Q and K are different.
+            # NB we currently don't have any implementations for quantized KV with
+            # SUPPORTS_DIFFERENT_VALUE_EMBED.
             output_shape = tuple(self.query.shape)
         else:
             output_shape = (self.query.shape[:-1]) + (self.value.shape[-1],)
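
Note: the branch above chooses the attention output's embedding dim. A self-contained sketch of the rule, assuming plain torch tensors and made-up shapes (the real logic operates on the Inputs dataclass):

    import torch

    def output_shape(query: torch.Tensor, value: torch.Tensor) -> tuple:
        if value.dtype == torch.int32:
            # Quantized K/V: value's last dim is a packed int32 payload,
            # so the output embedding dim comes from the query instead.
            return tuple(query.shape)
        # Default: query's leading dims, value's embedding dim.
        return tuple(query.shape[:-1]) + (value.shape[-1],)

    q = torch.randn(2, 16, 8, 64)  # [B, Mq, H, K]
    v = torch.randn(2, 16, 8, 32)  # [B, Mkv, H, Kv]
    assert output_shape(q, v) == (2, 16, 8, 32)

    v_packed = torch.zeros(2, 16, 8, 34, dtype=torch.int32)
    assert output_shape(q, v_packed) == (2, 16, 8, 64)
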
@@ -151,9 +153,11 @@ def validate_inputs(self) -> None:
             )
         H = self.query.shape[-2]
         if self.query.ndim == 4:  # BMHK
+            quantized_kv_cache = self.value.dtype == torch.int32
+            key_embed_dim = Kv if quantized_kv_cache else K
             valid_shapes = (
                 self.query.shape == (B, Mq, H, K)
-                and self.key.shape == (B, Mkv, H, K)
+                and self.key.shape == (B, Mkv, H, key_embed_dim)
                 and self.value.shape == (B, Mkv, H, Kv)
             )
         G = self.query.shape[2]
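
Note: a standalone sketch of the relaxed BMHK check, with hypothetical names; the real version lives in Inputs.validate_inputs and raises on mismatch rather than returning a bool:

    import torch

    def bmhk_shapes_ok(query, key, value) -> bool:
        B, Mq, H, K = query.shape
        Mkv, Kv = key.shape[1], value.shape[-1]
        quantized_kv_cache = value.dtype == torch.int32
        # With int32-quantized K/V, the key's last dim matches the packed
        # value dim Kv rather than the query embedding dim K.
        key_embed_dim = Kv if quantized_kv_cache else K
        return (
            key.shape == (B, Mkv, H, key_embed_dim)
            and value.shape == (B, Mkv, H, Kv)
        )

    q = torch.randn(2, 16, 8, 64)
    k = torch.zeros(2, 48, 8, 34, dtype=torch.int32)
    v = torch.zeros(2, 48, 8, 34, dtype=torch.int32)
    assert bmhk_shapes_ok(q, k, v)  # quantized: key dim follows Kv, not K
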
13 changes: 13 additions & 0 deletions xformers/ops/fmha/flash.py
@@ -395,6 +395,19 @@ class BwOp(AttentionBwOpBase):

     MAX_HEADDIM_SM8x = 192

+    @classmethod
+    def shape_not_supported_reasons(
+        cls, Mq: int, Mkv: int, K: int, Kv: int
+    ) -> List[str]:
+        reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv)
+
+        # In fbcode in mode/dev-nosan, we get nans from flash v2.1 if there
+        # is a strange embedding dimension.
+        if K not in {8, 16, 32, 64, 128, 256}:
+            reasons.append(f"Embed dim {K} not supported")
+
+        return reasons
+
     @classmethod
     def not_supported_reasons(cls, d: Inputs) -> List[str]:
         reasons = super(BwOp, cls).not_supported_reasons(d)
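
Note: a toy reproduction of the new head-dim gate, outside the real operator class. The supported set {8, 16, 32, 64, 128, 256} is copied from the diff; everything else is illustrative:

    from typing import List

    _SUPPORTED_HEAD_DIMS = {8, 16, 32, 64, 128, 256}

    def shape_not_supported_reasons(Mq: int, Mkv: int, K: int, Kv: int) -> List[str]:
        reasons: List[str] = []
        # Reject the "strange" embedding dims that produced nans with
        # flash v2.1 in fbcode's mode/dev-nosan builds.
        if K not in _SUPPORTED_HEAD_DIMS:
            reasons.append(f"Embed dim {K} not supported")
        return reasons

    assert shape_not_supported_reasons(1024, 1024, 96, 96) == [
        "Embed dim 96 not supported"
    ]
    assert shape_not_supported_reasons(1024, 1024, 64, 64) == []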
