wrap forward passes with torch.no_grad() (#19413)
This commit is contained in:
@@ -493,7 +493,8 @@ class FNetModelIntegrationTest(unittest.TestCase):
|
||||
model.to(torch_device)
|
||||
|
||||
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device)
|
||||
output = model(input_ids)[0]
|
||||
with torch.no_grad():
|
||||
output = model(input_ids)[0]
|
||||
|
||||
vocab_size = 32000
|
||||
|
||||
@@ -536,7 +537,8 @@ class FNetModelIntegrationTest(unittest.TestCase):
|
||||
model.to(torch_device)
|
||||
|
||||
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device)
|
||||
output = model(input_ids)[0]
|
||||
with torch.no_grad():
|
||||
output = model(input_ids)[0]
|
||||
|
||||
expected_shape = torch.Size((1, 2))
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
@@ -551,7 +553,8 @@ class FNetModelIntegrationTest(unittest.TestCase):
|
||||
model.to(torch_device)
|
||||
|
||||
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device)
|
||||
output = model(input_ids)[0]
|
||||
with torch.no_grad():
|
||||
output = model(input_ids)[0]
|
||||
|
||||
expected_shape = torch.Size((1, 6, model.config.hidden_size))
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
|
||||
Reference in New Issue
Block a user