Add 'with torch.no_grad()' to BertGeneration integration test forward passes (#14963)
This commit is contained in:
@@ -307,7 +307,8 @@ class BertGenerationEncoderIntegrationTest(unittest.TestCase):
|
|||||||
def test_inference_no_head_absolute_embedding(self):
|
def test_inference_no_head_absolute_embedding(self):
|
||||||
model = BertGenerationEncoder.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
|
model = BertGenerationEncoder.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
|
||||||
input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]])
|
input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]])
|
||||||
output = model(input_ids)[0]
|
with torch.no_grad():
|
||||||
|
output = model(input_ids)[0]
|
||||||
expected_shape = torch.Size([1, 8, 1024])
|
expected_shape = torch.Size([1, 8, 1024])
|
||||||
self.assertEqual(output.shape, expected_shape)
|
self.assertEqual(output.shape, expected_shape)
|
||||||
expected_slice = torch.tensor(
|
expected_slice = torch.tensor(
|
||||||
@@ -322,7 +323,8 @@ class BertGenerationDecoderIntegrationTest(unittest.TestCase):
|
|||||||
def test_inference_no_head_absolute_embedding(self):
|
def test_inference_no_head_absolute_embedding(self):
|
||||||
model = BertGenerationDecoder.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
|
model = BertGenerationDecoder.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
|
||||||
input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]])
|
input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]])
|
||||||
output = model(input_ids)[0]
|
with torch.no_grad():
|
||||||
|
output = model(input_ids)[0]
|
||||||
expected_shape = torch.Size([1, 8, 50358])
|
expected_shape = torch.Size([1, 8, 50358])
|
||||||
self.assertEqual(output.shape, expected_shape)
|
self.assertEqual(output.shape, expected_shape)
|
||||||
expected_slice = torch.tensor(
|
expected_slice = torch.tensor(
|
||||||
|
|||||||
Reference in New Issue
Block a user