fix bug with padding mask + add corresponding test

This commit is contained in:
Rémi Louf
2019-10-30 11:19:58 +01:00
parent 3b0d2fa30e
commit da10de8466
2 changed files with 10 additions and 3 deletions

View File

@@ -127,9 +127,9 @@ def build_lm_labels(sequence, pad_token):
def build_mask(sequence, pad_token): def build_mask(sequence, pad_token):
""" Builds the mask. The attention mechanism will only attend to positions """ Builds the mask. The attention mechanism will only attend to positions
with value 1. """ with value 1. """
mask = sequence.clone() mask = torch.ones_like(sequence)
mask[mask != pad_token] = 1 idx_pad_tokens = (sequence == pad_token)
mask[mask == pad_token] = 0 mask[idx_pad_tokens] = 0
return mask return mask

View File

@@ -116,6 +116,13 @@ class SummarizationDataProcessingTest(unittest.TestCase):
build_mask(sequence, 23).numpy(), expected.numpy() build_mask(sequence, 23).numpy(), expected.numpy()
) )
def test_build_mask_with_padding_equal_to_one(self):
sequence = torch.tensor([8, 2, 3, 4, 1, 1, 1])
expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
np.testing.assert_array_equal(
build_mask(sequence, 1).numpy(), expected.numpy()
)
def test_compute_token_type_ids(self): def test_compute_token_type_ids(self):
separator = 101 separator = 101
batch = torch.tensor( batch = torch.tensor(