2 SinusoidalPositionalEmbedding fixes (#8226)
This commit is contained in:
@@ -620,8 +620,8 @@ class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
|
||||
self.assertListEqual(no_cache[-1].tolist(), yes_cache[0][0].tolist())
|
||||
|
||||
def test_odd_embed_dim(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
SinusoidalPositionalEmbedding(num_positions=4, embedding_dim=5, padding_idx=0).to(torch_device)
|
||||
# odd embedding_dim is allowed
|
||||
SinusoidalPositionalEmbedding(num_positions=4, embedding_dim=5, padding_idx=0).to(torch_device)
|
||||
|
||||
# odd num_positions is allowed
|
||||
SinusoidalPositionalEmbedding(num_positions=5, embedding_dim=4, padding_idx=0).to(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user