Allow gpt2 to be exported to valid ONNX (#4244)
* allow gpt2 to be exported to valid ONNX model * cast size from int to float explicitly
This commit is contained in:
@@ -26,7 +26,7 @@ def gelu_new(x):
|
|||||||
""" Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
|
""" Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
|
||||||
Also see https://arxiv.org/abs/1606.08415
|
Also see https://arxiv.org/abs/1606.08415
|
||||||
"""
|
"""
|
||||||
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
|
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
|
||||||
|
|
||||||
|
|
||||||
if torch.__version__ < "1.4.0":
|
if torch.__version__ < "1.4.0":
|
||||||
@@ -36,7 +36,7 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
def gelu_fast(x):
|
def gelu_fast(x):
|
||||||
return 0.5 * x * (1 + torch.tanh(x * 0.7978845608 * (1 + 0.044715 * x * x)))
|
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
|
||||||
|
|
||||||
|
|
||||||
ACT2FN = {
|
ACT2FN = {
|
||||||
|
|||||||
@@ -142,10 +142,10 @@ class Attention(nn.Module):
|
|||||||
def _attn(self, q, k, v, attention_mask=None, head_mask=None):
|
def _attn(self, q, k, v, attention_mask=None, head_mask=None):
|
||||||
w = torch.matmul(q, k)
|
w = torch.matmul(q, k)
|
||||||
if self.scale:
|
if self.scale:
|
||||||
w = w / (v.size(-1) ** 0.5)
|
w = w / (float(v.size(-1)) ** 0.5)
|
||||||
nd, ns = w.size(-2), w.size(-1)
|
nd, ns = w.size(-2), w.size(-1)
|
||||||
mask = self.bias[:, :, ns - nd : ns, :ns]
|
mask = self.bias[:, :, ns - nd : ns, :ns]
|
||||||
w = torch.where(mask, w, self.masked_bias.to(w.dtype))
|
w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# Apply the attention mask
|
# Apply the attention mask
|
||||||
|
|||||||
Reference in New Issue
Block a user