From a9782881a40ecc905a658b6cd3e561548d78c8ec Mon Sep 17 00:00:00 2001 From: Partho Date: Tue, 4 Oct 2022 19:43:22 +0530 Subject: [PATCH] wrap forward passes with torch.no_grad() (#19273) --- tests/models/big_bird/test_modeling_big_bird.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/models/big_bird/test_modeling_big_bird.py b/tests/models/big_bird/test_modeling_big_bird.py index ec59f8f93d..ec8705607d 100644 --- a/tests/models/big_bird/test_modeling_big_bird.py +++ b/tests/models/big_bird/test_modeling_big_bird.py @@ -627,7 +627,8 @@ class BigBirdModelIntegrationTest(unittest.TestCase): model.to(torch_device) input_ids = torch.tensor([[20920, 232, 328, 1437] * 1024], dtype=torch.long, device=torch_device) - outputs = model(input_ids) + with torch.no_grad(): + outputs = model(input_ids) prediction_logits = outputs.prediction_logits seq_relationship_logits = outputs.seq_relationship_logits @@ -655,7 +656,8 @@ class BigBirdModelIntegrationTest(unittest.TestCase): model.to(torch_device) input_ids = torch.tensor([[20920, 232, 328, 1437] * 512], dtype=torch.long, device=torch_device) - outputs = model(input_ids) + with torch.no_grad(): + outputs = model(input_ids) prediction_logits = outputs.prediction_logits seq_relationship_logits = outputs.seq_relationship_logits @@ -920,7 +922,8 @@ class BigBirdModelIntegrationTest(unittest.TestCase): model.eval() input_ids = torch.tensor([200 * [10] + 40 * [2] + [1]], device=torch_device, dtype=torch.long) - output = model(input_ids).to_tuple()[0] + with torch.no_grad(): + output = model(input_ids).to_tuple()[0] # fmt: off target = torch.tensor(