Add 'with torch.no_grad()' to BEiT integration test forward passes (#14961)
* Add 'with torch.no_grad()' to BEiT integration test forward pass * Fix inconsistent use of tabs and spaces in indentation
This commit is contained in:
@@ -435,7 +435,8 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
||||
bool_masked_pos = torch.ones((1, 196), dtype=torch.bool).to(torch_device)
|
||||
|
||||
# forward pass
|
||||
outputs = model(pixel_values=pixel_values, bool_masked_pos=bool_masked_pos)
|
||||
with torch.no_grad():
|
||||
outputs = model(pixel_values=pixel_values, bool_masked_pos=bool_masked_pos)
|
||||
logits = outputs.logits
|
||||
|
||||
# verify the logits
|
||||
@@ -457,7 +458,8 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
||||
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
||||
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.logits
|
||||
|
||||
# verify the logits
|
||||
@@ -482,7 +484,8 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
||||
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
||||
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.logits
|
||||
|
||||
# verify the logits
|
||||
@@ -508,7 +511,8 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
||||
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
||||
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.logits
|
||||
|
||||
# verify the logits
|
||||
|
||||
Reference in New Issue
Block a user