From 070507df1ffd7609f4691089f0bbc7ac27df66fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Wed, 30 Oct 2019 11:24:12 +0100 Subject: [PATCH] format utils for summarization --- examples/utils_summarization.py | 2 +- examples/utils_summarization_test.py | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/examples/utils_summarization.py b/examples/utils_summarization.py index 2a8f81cd36..327ca8cc3e 100644 --- a/examples/utils_summarization.py +++ b/examples/utils_summarization.py @@ -128,7 +128,7 @@ def build_mask(sequence, pad_token): """ Builds the mask. The attention mechanism will only attend to positions with value 1. """ mask = torch.ones_like(sequence) - idx_pad_tokens = (sequence == pad_token) + idx_pad_tokens = sequence == pad_token mask[idx_pad_tokens] = 0 return mask diff --git a/examples/utils_summarization_test.py b/examples/utils_summarization_test.py index 7604bd185d..1d56ff0803 100644 --- a/examples/utils_summarization_test.py +++ b/examples/utils_summarization_test.py @@ -105,9 +105,7 @@ class SummarizationDataProcessingTest(unittest.TestCase): def test_build_mask_no_padding(self): sequence = torch.tensor([1, 2, 3, 4]) expected = torch.tensor([1, 1, 1, 1]) - np.testing.assert_array_equal( - build_mask(sequence, 0).numpy(), expected.numpy() - ) + np.testing.assert_array_equal(build_mask(sequence, 0).numpy(), expected.numpy()) def test_build_mask(self): sequence = torch.tensor([1, 2, 3, 4, 23, 23, 23]) @@ -119,9 +117,7 @@ class SummarizationDataProcessingTest(unittest.TestCase): def test_build_mask_with_padding_equal_to_one(self): sequence = torch.tensor([8, 2, 3, 4, 1, 1, 1]) expected = torch.tensor([1, 1, 1, 1, 0, 0, 0]) - np.testing.assert_array_equal( - build_mask(sequence, 1).numpy(), expected.numpy() - ) + np.testing.assert_array_equal(build_mask(sequence, 1).numpy(), expected.numpy()) def test_compute_token_type_ids(self): separator = 101