wrap forward passes with torch.no_grad() (#19273)
This commit is contained in:
@@ -627,7 +627,8 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
|
|||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
|
|
||||||
input_ids = torch.tensor([[20920, 232, 328, 1437] * 1024], dtype=torch.long, device=torch_device)
|
input_ids = torch.tensor([[20920, 232, 328, 1437] * 1024], dtype=torch.long, device=torch_device)
|
||||||
outputs = model(input_ids)
|
with torch.no_grad():
|
||||||
|
outputs = model(input_ids)
|
||||||
prediction_logits = outputs.prediction_logits
|
prediction_logits = outputs.prediction_logits
|
||||||
seq_relationship_logits = outputs.seq_relationship_logits
|
seq_relationship_logits = outputs.seq_relationship_logits
|
||||||
|
|
||||||
@@ -655,7 +656,8 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
|
|||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
|
|
||||||
input_ids = torch.tensor([[20920, 232, 328, 1437] * 512], dtype=torch.long, device=torch_device)
|
input_ids = torch.tensor([[20920, 232, 328, 1437] * 512], dtype=torch.long, device=torch_device)
|
||||||
outputs = model(input_ids)
|
with torch.no_grad():
|
||||||
|
outputs = model(input_ids)
|
||||||
prediction_logits = outputs.prediction_logits
|
prediction_logits = outputs.prediction_logits
|
||||||
seq_relationship_logits = outputs.seq_relationship_logits
|
seq_relationship_logits = outputs.seq_relationship_logits
|
||||||
|
|
||||||
@@ -920,7 +922,8 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
input_ids = torch.tensor([200 * [10] + 40 * [2] + [1]], device=torch_device, dtype=torch.long)
|
input_ids = torch.tensor([200 * [10] + 40 * [2] + [1]], device=torch_device, dtype=torch.long)
|
||||||
output = model(input_ids).to_tuple()[0]
|
with torch.no_grad():
|
||||||
|
output = model(input_ids).to_tuple()[0]
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
target = torch.tensor(
|
target = torch.tensor(
|
||||||
|
|||||||
Reference in New Issue
Block a user