format utils for summarization

This commit is contained in:
Rémi Louf
2019-10-30 11:24:12 +01:00
parent da10de8466
commit 070507df1f
2 changed files with 3 additions and 7 deletions

View File

@@ -128,7 +128,7 @@ def build_mask(sequence, pad_token):
""" Builds the mask. The attention mechanism will only attend to positions """ Builds the mask. The attention mechanism will only attend to positions
with value 1. """ with value 1. """
mask = torch.ones_like(sequence) mask = torch.ones_like(sequence)
idx_pad_tokens = (sequence == pad_token) idx_pad_tokens = sequence == pad_token
mask[idx_pad_tokens] = 0 mask[idx_pad_tokens] = 0
return mask return mask

View File

@@ -105,9 +105,7 @@ class SummarizationDataProcessingTest(unittest.TestCase):
def test_build_mask_no_padding(self): def test_build_mask_no_padding(self):
sequence = torch.tensor([1, 2, 3, 4]) sequence = torch.tensor([1, 2, 3, 4])
expected = torch.tensor([1, 1, 1, 1]) expected = torch.tensor([1, 1, 1, 1])
np.testing.assert_array_equal( np.testing.assert_array_equal(build_mask(sequence, 0).numpy(), expected.numpy())
build_mask(sequence, 0).numpy(), expected.numpy()
)
def test_build_mask(self): def test_build_mask(self):
sequence = torch.tensor([1, 2, 3, 4, 23, 23, 23]) sequence = torch.tensor([1, 2, 3, 4, 23, 23, 23])
@@ -119,9 +117,7 @@ class SummarizationDataProcessingTest(unittest.TestCase):
def test_build_mask_with_padding_equal_to_one(self): def test_build_mask_with_padding_equal_to_one(self):
sequence = torch.tensor([8, 2, 3, 4, 1, 1, 1]) sequence = torch.tensor([8, 2, 3, 4, 1, 1, 1])
expected = torch.tensor([1, 1, 1, 1, 0, 0, 0]) expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
np.testing.assert_array_equal( np.testing.assert_array_equal(build_mask(sequence, 1).numpy(), expected.numpy())
build_mask(sequence, 1).numpy(), expected.numpy()
)
def test_compute_token_type_ids(self): def test_compute_token_type_ids(self):
separator = 101 separator = 101