Allow gpt2 to be exported to valid ONNX (#4244)

* allow gpt2 to be exported to valid ONNX model

* cast size from int to float explictly
This commit is contained in:
Tianlei Wu
2020-05-11 11:55:55 -07:00
committed by GitHub
parent 39994051e4
commit 82601f4c1a
2 changed files with 4 additions and 4 deletions

View File

@@ -26,7 +26,7 @@ def gelu_new(x):
""" Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
Also see https://arxiv.org/abs/1606.08415
"""
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
if torch.__version__ < "1.4.0":
@@ -36,7 +36,7 @@ else:
def gelu_fast(x):
return 0.5 * x * (1 + torch.tanh(x * 0.7978845608 * (1 + 0.044715 * x * x)))
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
ACT2FN = {

View File

@@ -142,10 +142,10 @@ class Attention(nn.Module):
def _attn(self, q, k, v, attention_mask=None, head_mask=None):
w = torch.matmul(q, k)
if self.scale:
w = w / (v.size(-1) ** 0.5)
w = w / (float(v.size(-1)) ** 0.5)
nd, ns = w.size(-2), w.size(-1)
mask = self.bias[:, :, ns - nd : ns, :ns]
w = torch.where(mask, w, self.masked_bias.to(w.dtype))
w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
if attention_mask is not None:
# Apply the attention mask