wrap forward passes with torch.no_grad() (#19413)

This commit is contained in:
Partho
2022-10-11 00:33:46 +05:30
committed by GitHub
parent c6a928cadb
commit 5f5e264a12

View File

@@ -493,6 +493,7 @@ class FNetModelIntegrationTest(unittest.TestCase):
model.to(torch_device) model.to(torch_device)
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device) input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device)
with torch.no_grad():
output = model(input_ids)[0] output = model(input_ids)[0]
vocab_size = 32000 vocab_size = 32000
@@ -536,6 +537,7 @@ class FNetModelIntegrationTest(unittest.TestCase):
model.to(torch_device) model.to(torch_device)
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device) input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device)
with torch.no_grad():
output = model(input_ids)[0] output = model(input_ids)[0]
expected_shape = torch.Size((1, 2)) expected_shape = torch.Size((1, 2))
@@ -551,6 +553,7 @@ class FNetModelIntegrationTest(unittest.TestCase):
model.to(torch_device) model.to(torch_device)
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device) input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device)
with torch.no_grad():
output = model(input_ids)[0] output = model(input_ids)[0]
expected_shape = torch.Size((1, 6, model.config.hidden_size)) expected_shape = torch.Size((1, 6, model.config.hidden_size))