fix function that defines masks in XLM
the definition of `get_masks` would blow with the proper combination of arguments. It was just a matter of moving a definition outside of a control structure.
This commit is contained in:
@@ -73,16 +73,16 @@ def get_masks(slen, lengths, causal, padding_mask=None):
|
|||||||
"""
|
"""
|
||||||
Generate hidden states mask, and optionally an attention mask.
|
Generate hidden states mask, and optionally an attention mask.
|
||||||
"""
|
"""
|
||||||
bs = lengths.size(0)
|
alen = torch.arange(slen, dtype=torch.long, device=lengths.device)
|
||||||
if padding_mask is not None:
|
if padding_mask is not None:
|
||||||
mask = padding_mask
|
mask = padding_mask
|
||||||
else:
|
else:
|
||||||
assert lengths.max().item() <= slen
|
assert lengths.max().item() <= slen
|
||||||
alen = torch.arange(slen, dtype=torch.long, device=lengths.device)
|
|
||||||
mask = alen < lengths[:, None]
|
mask = alen < lengths[:, None]
|
||||||
|
|
||||||
# attention mask is the same as mask, or triangular inferior attention (causal)
|
# attention mask is the same as mask, or triangular inferior attention (causal)
|
||||||
if causal:
|
if causal:
|
||||||
|
bs = lengths.size(0)
|
||||||
attn_mask = alen[None, None, :].repeat(bs, slen, 1) <= alen[None, :, None]
|
attn_mask = alen[None, None, :].repeat(bs, slen, 1) <= alen[None, :, None]
|
||||||
else:
|
else:
|
||||||
attn_mask = mask
|
attn_mask = mask
|
||||||
|
|||||||
Reference in New Issue
Block a user