Fix BART tests on GPU (#4298)

This commit is contained in:
Julien Chaumond
2020-05-12 09:11:50 -04:00
committed by GitHub
parent e4512aab3b
commit 4bf5042240
2 changed files with 6 additions and 2 deletions

View File

@@ -690,4 +690,8 @@ class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
# test that forward pass is just a lookup, there is no ignore padding logic
input_ids = torch.tensor([[4, 10, pad, pad, pad]], dtype=torch.long, device=torch_device)
no_cache_pad_zero = emb1(input_ids)
self.assertTrue(torch.allclose(torch.Tensor(self.desired_weights), no_cache_pad_zero[:3, :5], atol=1e-3))
self.assertTrue(
torch.allclose(
torch.tensor(self.desired_weights, device=torch_device), no_cache_pad_zero[:3, :5], atol=1e-3
)
)