2 SinusoidalPositionalEmbedding fixes (#8226)
This commit is contained in:
@@ -1328,8 +1328,6 @@ class SinusoidalPositionalEmbedding(nn.Embedding):
|
|||||||
|
|
||||||
def __init__(self, num_positions, embedding_dim, padding_idx=None):
|
def __init__(self, num_positions, embedding_dim, padding_idx=None):
|
||||||
super().__init__(num_positions, embedding_dim)
|
super().__init__(num_positions, embedding_dim)
|
||||||
if embedding_dim % 2 != 0:
|
|
||||||
raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
|
|
||||||
self.weight = self._init_weight(self.weight)
|
self.weight = self._init_weight(self.weight)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -1342,10 +1340,11 @@ class SinusoidalPositionalEmbedding(nn.Embedding):
|
|||||||
position_enc = np.array(
|
position_enc = np.array(
|
||||||
[[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
|
[[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
|
||||||
)
|
)
|
||||||
out[:, 0 : dim // 2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) # This line breaks for odd dim (embedding_dim), not odd n_pos: the sin slice has one more column than out[:, 0 : dim // 2]
|
out.requires_grad = False # set early to avoid an error in pytorch-1.8+
|
||||||
out[:, dim // 2 :] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
|
sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
|
||||||
|
out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
|
||||||
|
out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
|
||||||
out.detach_()
|
out.detach_()
|
||||||
out.requires_grad = False
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
|||||||
@@ -620,8 +620,8 @@ class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
|
|||||||
self.assertListEqual(no_cache[-1].tolist(), yes_cache[0][0].tolist())
|
self.assertListEqual(no_cache[-1].tolist(), yes_cache[0][0].tolist())
|
||||||
|
|
||||||
def test_odd_embed_dim(self):
|
def test_odd_embed_dim(self):
|
||||||
with self.assertRaises(NotImplementedError):
|
# odd embedding_dim is allowed
|
||||||
SinusoidalPositionalEmbedding(num_positions=4, embedding_dim=5, padding_idx=0).to(torch_device)
|
SinusoidalPositionalEmbedding(num_positions=4, embedding_dim=5, padding_idx=0).to(torch_device)
|
||||||
|
|
||||||
# odd num_positions is allowed
|
# odd num_positions is allowed
|
||||||
SinusoidalPositionalEmbedding(num_positions=5, embedding_dim=4, padding_idx=0).to(torch_device)
|
SinusoidalPositionalEmbedding(num_positions=5, embedding_dim=4, padding_idx=0).to(torch_device)
|
||||||
|
|||||||
Reference in New Issue
Block a user