From 5f5e264a12956bd7cce47dcb422b80ed68e4c24e Mon Sep 17 00:00:00 2001 From: Partho Date: Tue, 11 Oct 2022 00:33:46 +0530 Subject: [PATCH] wrap forward passes with torch.no_grad() (#19413) --- tests/models/fnet/test_modeling_fnet.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/models/fnet/test_modeling_fnet.py b/tests/models/fnet/test_modeling_fnet.py index 974d7c2d4e..5d975b061f 100644 --- a/tests/models/fnet/test_modeling_fnet.py +++ b/tests/models/fnet/test_modeling_fnet.py @@ -493,7 +493,8 @@ class FNetModelIntegrationTest(unittest.TestCase): model.to(torch_device) input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device) - output = model(input_ids)[0] + with torch.no_grad(): + output = model(input_ids)[0] vocab_size = 32000 @@ -536,7 +537,8 @@ class FNetModelIntegrationTest(unittest.TestCase): model.to(torch_device) input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device) - output = model(input_ids)[0] + with torch.no_grad(): + output = model(input_ids)[0] expected_shape = torch.Size((1, 2)) self.assertEqual(output.shape, expected_shape) @@ -551,7 +553,8 @@ class FNetModelIntegrationTest(unittest.TestCase): model.to(torch_device) input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device) - output = model(input_ids)[0] + with torch.no_grad(): + output = model(input_ids)[0] expected_shape = torch.Size((1, 6, model.config.hidden_size)) self.assertEqual(output.shape, expected_shape)