diff --git a/tests/models/deit/test_modeling_deit.py b/tests/models/deit/test_modeling_deit.py index 27f92c2d97..82b7f28692 100644 --- a/tests/models/deit/test_modeling_deit.py +++ b/tests/models/deit/test_modeling_deit.py @@ -384,7 +384,8 @@ class DeiTModelIntegrationTest(unittest.TestCase): inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device) # forward pass - outputs = model(**inputs) + with torch.no_grad(): + outputs = model(**inputs) # verify the logits expected_shape = torch.Size((1, 1000))