Add 'with torch.no_grad()' to BEiT integration test forward passes (#14961)
* Add 'with torch.no_grad()' to BEiT integration test forward pass * Fix inconsistent use of tabs and spaces in indentation
This commit is contained in:
@@ -435,6 +435,7 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
|||||||
bool_masked_pos = torch.ones((1, 196), dtype=torch.bool).to(torch_device)
|
bool_masked_pos = torch.ones((1, 196), dtype=torch.bool).to(torch_device)
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
|
with torch.no_grad():
|
||||||
outputs = model(pixel_values=pixel_values, bool_masked_pos=bool_masked_pos)
|
outputs = model(pixel_values=pixel_values, bool_masked_pos=bool_masked_pos)
|
||||||
logits = outputs.logits
|
logits = outputs.logits
|
||||||
|
|
||||||
@@ -457,6 +458,7 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
|||||||
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
|
with torch.no_grad():
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
logits = outputs.logits
|
logits = outputs.logits
|
||||||
|
|
||||||
@@ -482,6 +484,7 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
|||||||
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
|
with torch.no_grad():
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
logits = outputs.logits
|
logits = outputs.logits
|
||||||
|
|
||||||
@@ -508,6 +511,7 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
|||||||
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
|
with torch.no_grad():
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
logits = outputs.logits
|
logits = outputs.logits
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user