format utils for summarization

This commit is contained in:
Rémi Louf
2019-10-30 11:24:12 +01:00
parent da10de8466
commit 070507df1f
2 changed files with 3 additions and 7 deletions

View File

@@ -128,7 +128,7 @@ def build_mask(sequence, pad_token):
""" Builds the mask. The attention mechanism will only attend to positions
with value 1. """
mask = torch.ones_like(sequence)
idx_pad_tokens = (sequence == pad_token)
idx_pad_tokens = sequence == pad_token
mask[idx_pad_tokens] = 0
return mask