From 0167edc8549b8c8b01fcc669d0da7ea41903b244 Mon Sep 17 00:00:00 2001 From: MrinalTyagi Date: Mon, 17 Jan 2022 17:52:41 +0530 Subject: [PATCH] Added forward pass of test_inference_image_classification_head with torch.no_grad() (#14777) --- tests/test_modeling_vit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_modeling_vit.py b/tests/test_modeling_vit.py index c24ae535a1..ac96f553e4 100644 --- a/tests/test_modeling_vit.py +++ b/tests/test_modeling_vit.py @@ -347,7 +347,8 @@ class ViTModelIntegrationTest(unittest.TestCase): inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device) # forward pass - outputs = model(**inputs) + with torch.no_grad(): + outputs = model(**inputs) # verify the logits expected_shape = torch.Size((1, 1000))