format utils for summarization
This commit is contained in:
@@ -128,7 +128,7 @@ def build_mask(sequence, pad_token):
|
|||||||
""" Builds the mask. The attention mechanism will only attend to positions
|
""" Builds the mask. The attention mechanism will only attend to positions
|
||||||
with value 1. """
|
with value 1. """
|
||||||
mask = torch.ones_like(sequence)
|
mask = torch.ones_like(sequence)
|
||||||
idx_pad_tokens = (sequence == pad_token)
|
idx_pad_tokens = sequence == pad_token
|
||||||
mask[idx_pad_tokens] = 0
|
mask[idx_pad_tokens] = 0
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|||||||
@@ -105,9 +105,7 @@ class SummarizationDataProcessingTest(unittest.TestCase):
|
|||||||
def test_build_mask_no_padding(self):
|
def test_build_mask_no_padding(self):
|
||||||
sequence = torch.tensor([1, 2, 3, 4])
|
sequence = torch.tensor([1, 2, 3, 4])
|
||||||
expected = torch.tensor([1, 1, 1, 1])
|
expected = torch.tensor([1, 1, 1, 1])
|
||||||
np.testing.assert_array_equal(
|
np.testing.assert_array_equal(build_mask(sequence, 0).numpy(), expected.numpy())
|
||||||
build_mask(sequence, 0).numpy(), expected.numpy()
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_build_mask(self):
|
def test_build_mask(self):
|
||||||
sequence = torch.tensor([1, 2, 3, 4, 23, 23, 23])
|
sequence = torch.tensor([1, 2, 3, 4, 23, 23, 23])
|
||||||
@@ -119,9 +117,7 @@ class SummarizationDataProcessingTest(unittest.TestCase):
|
|||||||
def test_build_mask_with_padding_equal_to_one(self):
|
def test_build_mask_with_padding_equal_to_one(self):
|
||||||
sequence = torch.tensor([8, 2, 3, 4, 1, 1, 1])
|
sequence = torch.tensor([8, 2, 3, 4, 1, 1, 1])
|
||||||
expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
|
expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
|
||||||
np.testing.assert_array_equal(
|
np.testing.assert_array_equal(build_mask(sequence, 1).numpy(), expected.numpy())
|
||||||
build_mask(sequence, 1).numpy(), expected.numpy()
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_compute_token_type_ids(self):
|
def test_compute_token_type_ids(self):
|
||||||
separator = 101
|
separator = 101
|
||||||
|
|||||||
Reference in New Issue
Block a user