2 SinusoidalPositionalEmbedding fixes (#8226)

This commit is contained in:
Stas Bekman
2020-11-02 15:50:26 -08:00
committed by GitHub
parent f744b81572
commit 504ff7bb12
2 changed files with 6 additions and 7 deletions

View File

@@ -620,8 +620,8 @@ class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
self.assertListEqual(no_cache[-1].tolist(), yes_cache[0][0].tolist())
def test_odd_embed_dim(self):
with self.assertRaises(NotImplementedError):
SinusoidalPositionalEmbedding(num_positions=4, embedding_dim=5, padding_idx=0).to(torch_device)
# odd embedding_dim is allowed
SinusoidalPositionalEmbedding(num_positions=4, embedding_dim=5, padding_idx=0).to(torch_device)
# odd num_positions is allowed
SinusoidalPositionalEmbedding(num_positions=5, embedding_dim=4, padding_idx=0).to(torch_device)