From da10de8466c001dceca328dac12751abb71c65eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Wed, 30 Oct 2019 11:19:58 +0100 Subject: [PATCH] fix bug with padding mask + add corresponding test --- examples/utils_summarization.py | 6 +++--- examples/utils_summarization_test.py | 7 +++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/examples/utils_summarization.py b/examples/utils_summarization.py index cd8bc4bc2b..2a8f81cd36 100644 --- a/examples/utils_summarization.py +++ b/examples/utils_summarization.py @@ -127,9 +127,9 @@ def build_lm_labels(sequence, pad_token): def build_mask(sequence, pad_token): """ Builds the mask. The attention mechanism will only attend to positions with value 1. """ - mask = sequence.clone() - mask[mask != pad_token] = 1 - mask[mask == pad_token] = 0 + mask = torch.ones_like(sequence) + idx_pad_tokens = (sequence == pad_token) + mask[idx_pad_tokens] = 0 return mask diff --git a/examples/utils_summarization_test.py b/examples/utils_summarization_test.py index 7a02f8fa1f..7604bd185d 100644 --- a/examples/utils_summarization_test.py +++ b/examples/utils_summarization_test.py @@ -116,6 +116,13 @@ class SummarizationDataProcessingTest(unittest.TestCase): build_mask(sequence, 23).numpy(), expected.numpy() ) + def test_build_mask_with_padding_equal_to_one(self): + sequence = torch.tensor([8, 2, 3, 4, 1, 1, 1]) + expected = torch.tensor([1, 1, 1, 1, 0, 0, 0]) + np.testing.assert_array_equal( + build_mask(sequence, 1).numpy(), expected.numpy() + ) + def test_compute_token_type_ids(self): separator = 101 batch = torch.tensor(