fix bug with padding mask + add corresponding test

This commit is contained in:
Rémi Louf
2019-10-30 11:19:58 +01:00
parent 3b0d2fa30e
commit da10de8466
2 changed files with 10 additions and 3 deletions

View File

@@ -127,9 +127,9 @@ def build_lm_labels(sequence, pad_token):
def build_mask(sequence, pad_token):
""" Builds the mask. The attention mechanism will only attend to positions
with value 1. """
mask = sequence.clone()
mask[mask != pad_token] = 1
mask[mask == pad_token] = 0
mask = torch.ones_like(sequence)
idx_pad_tokens = (sequence == pad_token)
mask[idx_pad_tokens] = 0
return mask