From f71fb5c36e739d8224419bb091b4c16531df829f Mon Sep 17 00:00:00 2001 From: Tavin Turner Date: Thu, 6 Jan 2022 08:39:13 -0700 Subject: [PATCH] Add 'with torch.no_grad()' to BertGeneration integration test forward passes (#14963) --- tests/test_modeling_bert_generation.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_modeling_bert_generation.py b/tests/test_modeling_bert_generation.py index c43e87d7fb..ea184af03b 100755 --- a/tests/test_modeling_bert_generation.py +++ b/tests/test_modeling_bert_generation.py @@ -307,7 +307,8 @@ class BertGenerationEncoderIntegrationTest(unittest.TestCase): def test_inference_no_head_absolute_embedding(self): model = BertGenerationEncoder.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder") input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]]) - output = model(input_ids)[0] + with torch.no_grad(): + output = model(input_ids)[0] expected_shape = torch.Size([1, 8, 1024]) self.assertEqual(output.shape, expected_shape) expected_slice = torch.tensor( @@ -322,7 +323,8 @@ class BertGenerationDecoderIntegrationTest(unittest.TestCase): def test_inference_no_head_absolute_embedding(self): model = BertGenerationDecoder.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder") input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]]) - output = model(input_ids)[0] + with torch.no_grad(): + output = model(input_ids)[0] expected_shape = torch.Size([1, 8, 50358]) self.assertEqual(output.shape, expected_shape) expected_slice = torch.tensor(