Commit 5c2f3ce

[minor] follow up from facebookresearch#117, macro blocks mask inputs, not attentions (facebookresearch#119)

* follow up from facebookresearch#117, macro blocks mask inputs, not attentions
* matching unit test
blefaudeux authored May 21, 2021
1 parent 1ee9e5c commit 5c2f3ce
Showing 2 changed files with 19 additions and 9 deletions.
20 changes: 15 additions & 5 deletions tests/test_model_factory.py
@@ -5,7 +5,13 @@
 BATCH = 20
 SEQ = 512
 
+DEVICES = (
+    [torch.device("cpu")]
+    if not torch.cuda.is_available()
+    else [
+        torch.device("cuda")
+    ]  # save a bit on CI for now, we have separate cpu and gpu jobs
+)
 
 test_configs = [
     {
@@ -92,10 +98,14 @@


 @pytest.mark.parametrize("config", test_configs)
-def test_presets(config):
+@pytest.mark.parametrize("device", DEVICES)
+def test_presets(config, device):
     # Build the model
-    model = xFormer.from_config(xFormerConfig(**config))
+    model = xFormer.from_config(xFormerConfig(**config)).to(device)
 
     # Dummy inputs, test a forward
-    inputs = (torch.rand(BATCH, SEQ) * 10).abs().to(torch.int)
-    _ = model(inputs)
+    inputs = (torch.rand(BATCH, SEQ, device=device) * 10).abs().to(torch.int)
+
+    input_mask = torch.randn(SEQ, dtype=torch.float, device=device)
+    input_mask[input_mask < 0.0] = -float("inf")
+    _ = model(inputs, encoder_input_mask=input_mask, decoder_input_mask=input_mask)
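
For reference, the mask built in the updated test follows the -inf convention: one random float is drawn per sequence position, and every negative entry is overwritten with -inf to mark that position as masked. A standalone sketch of that construction (illustrative only, not part of the commit; the masked-count report at the end is added purely for demonstration):

import torch

SEQ = 512

# Draw one value per sequence position, then mark roughly half the
# positions as masked by overwriting the negative draws with -inf.
input_mask = torch.randn(SEQ, dtype=torch.float)
input_mask[input_mask < 0.0] = -float("inf")

# Masked positions are exactly the -inf entries.
masked = torch.isinf(input_mask)
print(f"masked {int(masked.sum())} of {SEQ} positions")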
8 changes: 4 additions & 4 deletions xformers/factory/model_factory.py
@@ -71,23 +71,23 @@ def from_config(cls, config: xFormerConfig):
     def forward(
         self,
         inputs: torch.Tensor,
-        encoder_attn_mask: Optional[torch.Tensor] = None,
-        decoder_attn_mask: Optional[torch.Tensor] = None,
+        encoder_input_mask: Optional[torch.Tensor] = None,
+        decoder_input_mask: Optional[torch.Tensor] = None,
     ) -> Optional[torch.Tensor]:
         # Encode to latent space if encoder is present
         latent = inputs
 
         if self.encoders:
             for encoder in self.encoders:
-                latent = encoder(latent, encoder_attn_mask)
+                latent = encoder(latent, input_mask=encoder_input_mask)
 
         # If decoder: either use the encoder output or just decode; both options are possible
         if self.decoders:
             for decoder in self.decoders:
                 inputs = decoder(
                     target=inputs,
                     memory=latent,
-                    att_mask=decoder_attn_mask,
+                    input_mask=decoder_input_mask,
                 )
 
         return inputs
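
Taken together, the rename means callers now pass per-position input masks to the macro blocks instead of attention masks, under the new *_input_mask keyword names. A minimal sketch of the updated call site, assuming `config` is one of the preset dicts from tests/test_model_factory.py (the presets themselves are elided in this diff):

import torch
from xformers.factory.model_factory import xFormer, xFormerConfig

BATCH, SEQ = 20, 512

# Assumption: `config` is one of the elided preset dicts from
# tests/test_model_factory.py above.
model = xFormer.from_config(xFormerConfig(**config))

inputs = (torch.rand(BATCH, SEQ) * 10).abs().to(torch.int)

# Same -inf convention as the test: negative draws become masked positions.
input_mask = torch.randn(SEQ, dtype=torch.float)
input_mask[input_mask < 0.0] = -float("inf")

# Renamed keywords: encoder_input_mask / decoder_input_mask
# (previously encoder_attn_mask / decoder_attn_mask).
out = model(inputs, encoder_input_mask=input_mask, decoder_input_mask=input_mask)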