Added with torch.no_grad() to Camembert integration test (#21544)

add with torch.no_grad() to Camembert integration test

Co-authored-by: Bibi <Bibi@katies-mac.local>
This commit is contained in:
Katie Le
2023-02-10 04:58:29 -05:00
committed by GitHub
parent f83942684d
commit 21a2d900ec

View File

@@ -39,7 +39,8 @@ class CamembertModelIntegrationTest(unittest.TestCase):
device=torch_device, device=torch_device,
dtype=torch.long, dtype=torch.long,
) # J'aime le camembert ! ) # J'aime le camembert !
output = model(input_ids)["last_hidden_state"] with torch.no_grad():
output = model(input_ids)["last_hidden_state"]
expected_shape = torch.Size((1, 10, 768)) expected_shape = torch.Size((1, 10, 768))
self.assertEqual(output.shape, expected_shape) self.assertEqual(output.shape, expected_shape)
# compare the actual values for a slice. # compare the actual values for a slice.