From c5a94a6100afdd550fb3ea445d8bddc6b9769fcc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Wed, 16 Oct 2019 12:50:36 +0200 Subject: [PATCH] fix function that defines masks in XLM the definition of `get_masks` would blow with the proper combination of arguments. It was just a matter of moving a definition outside of a control structure. --- transformers/modeling_xlm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformers/modeling_xlm.py b/transformers/modeling_xlm.py index b29e721556..f1df6f668f 100644 --- a/transformers/modeling_xlm.py +++ b/transformers/modeling_xlm.py @@ -73,16 +73,16 @@ def get_masks(slen, lengths, causal, padding_mask=None): """ Generate hidden states mask, and optionally an attention mask. """ - bs = lengths.size(0) + alen = torch.arange(slen, dtype=torch.long, device=lengths.device) if padding_mask is not None: mask = padding_mask else: assert lengths.max().item() <= slen - alen = torch.arange(slen, dtype=torch.long, device=lengths.device) mask = alen < lengths[:, None] # attention mask is the same as mask, or triangular inferior attention (causal) if causal: + bs = lengths.size(0) attn_mask = alen[None, None, :].repeat(bs, slen, 1) <= alen[None, :, None] else: attn_mask = mask