From d6e920449ec26be7e15616ed6835ca96f0a99a56 Mon Sep 17 00:00:00 2001 From: Partho Date: Tue, 4 Oct 2022 19:42:03 +0530 Subject: [PATCH] wrap forward passes with torch.no_grad() (#19274) --- tests/models/convbert/test_modeling_convbert.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/convbert/test_modeling_convbert.py b/tests/models/convbert/test_modeling_convbert.py index d3eb0aec4c..f2b82aaadf 100644 --- a/tests/models/convbert/test_modeling_convbert.py +++ b/tests/models/convbert/test_modeling_convbert.py @@ -444,7 +444,8 @@ class ConvBertModelIntegrationTest(unittest.TestCase): def test_inference_no_head(self): model = ConvBertModel.from_pretrained("YituTech/conv-bert-base") input_ids = torch.tensor([[1, 2, 3, 4, 5, 6]]) - output = model(input_ids)[0] + with torch.no_grad(): + output = model(input_ids)[0] expected_shape = torch.Size((1, 6, 768)) self.assertEqual(output.shape, expected_shape)