fix bug with padding mask + add corresponding test

This commit is contained in:
Rémi Louf
2019-10-30 11:19:58 +01:00
parent 3b0d2fa30e
commit da10de8466
2 changed files with 10 additions and 3 deletions

View File

@@ -116,6 +116,13 @@ class SummarizationDataProcessingTest(unittest.TestCase):
build_mask(sequence, 23).numpy(), expected.numpy()
)
def test_build_mask_with_padding_equal_to_one(self):
sequence = torch.tensor([8, 2, 3, 4, 1, 1, 1])
expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
np.testing.assert_array_equal(
build_mask(sequence, 1).numpy(), expected.numpy()
)
def test_compute_token_type_ids(self):
separator = 101
batch = torch.tensor(