From f134d385535d10ee4c0950223e6ddfdc109c99df Mon Sep 17 00:00:00 2001 From: Partho Date: Tue, 4 Oct 2022 19:38:29 +0530 Subject: [PATCH] wrap forward passes with torch.no_grad() (#19279) --- tests/models/deit/test_modeling_deit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/deit/test_modeling_deit.py b/tests/models/deit/test_modeling_deit.py index 27f92c2d97..82b7f28692 100644 --- a/tests/models/deit/test_modeling_deit.py +++ b/tests/models/deit/test_modeling_deit.py @@ -384,7 +384,8 @@ class DeiTModelIntegrationTest(unittest.TestCase): inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device) # forward pass - outputs = model(**inputs) + with torch.no_grad(): + outputs = model(**inputs) # verify the logits expected_shape = torch.Size((1, 1000))