format utils for summarization

This commit is contained in:
Rémi Louf
2019-10-30 11:24:12 +01:00
parent da10de8466
commit 070507df1f
2 changed files with 3 additions and 7 deletions

View File

@@ -105,9 +105,7 @@ class SummarizationDataProcessingTest(unittest.TestCase):
def test_build_mask_no_padding(self):
sequence = torch.tensor([1, 2, 3, 4])
expected = torch.tensor([1, 1, 1, 1])
np.testing.assert_array_equal(
build_mask(sequence, 0).numpy(), expected.numpy()
)
np.testing.assert_array_equal(build_mask(sequence, 0).numpy(), expected.numpy())
def test_build_mask(self):
sequence = torch.tensor([1, 2, 3, 4, 23, 23, 23])
@@ -119,9 +117,7 @@ class SummarizationDataProcessingTest(unittest.TestCase):
def test_build_mask_with_padding_equal_to_one(self):
sequence = torch.tensor([8, 2, 3, 4, 1, 1, 1])
expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
np.testing.assert_array_equal(
build_mask(sequence, 1).numpy(), expected.numpy()
)
np.testing.assert_array_equal(build_mask(sequence, 1).numpy(), expected.numpy())
def test_compute_token_type_ids(self):
separator = 101