From d984b10335bd590c75dc188f6be15d1c6062b5dc Mon Sep 17 00:00:00 2001 From: Tavin Turner Date: Mon, 31 Jan 2022 13:12:10 -0700 Subject: [PATCH] Add 'with torch.no_grad()' to BEiT integration test forward passes (#14961) * Add 'with torch.no_grad()' to BEiT integration test forward pass * Fix inconsistent use of tabs and spaces in indentation --- tests/test_modeling_beit.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/test_modeling_beit.py b/tests/test_modeling_beit.py index db8bd8c6d0..777247161a 100644 --- a/tests/test_modeling_beit.py +++ b/tests/test_modeling_beit.py @@ -435,7 +435,8 @@ class BeitModelIntegrationTest(unittest.TestCase): bool_masked_pos = torch.ones((1, 196), dtype=torch.bool).to(torch_device) # forward pass - outputs = model(pixel_values=pixel_values, bool_masked_pos=bool_masked_pos) + with torch.no_grad(): + outputs = model(pixel_values=pixel_values, bool_masked_pos=bool_masked_pos) logits = outputs.logits # verify the logits @@ -457,7 +458,8 @@ class BeitModelIntegrationTest(unittest.TestCase): inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device) # forward pass - outputs = model(**inputs) + with torch.no_grad(): + outputs = model(**inputs) logits = outputs.logits # verify the logits @@ -482,7 +484,8 @@ class BeitModelIntegrationTest(unittest.TestCase): inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device) # forward pass - outputs = model(**inputs) + with torch.no_grad(): + outputs = model(**inputs) logits = outputs.logits # verify the logits @@ -508,7 +511,8 @@ class BeitModelIntegrationTest(unittest.TestCase): inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device) # forward pass - outputs = model(**inputs) + with torch.no_grad(): + outputs = model(**inputs) logits = outputs.logits # verify the logits