From 7ba83730c48f80e932450da86ec601131d0f3679 Mon Sep 17 00:00:00 2001 From: lukovnikov Date: Tue, 13 Nov 2018 16:31:20 +0100 Subject: [PATCH] clean up pr --- convert_tf_checkpoint_to_pytorch.py | 12 +++++++++--- modeling.py | 8 ++++++-- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/convert_tf_checkpoint_to_pytorch.py b/convert_tf_checkpoint_to_pytorch.py index dfcdbee42d..eeebb3728e 100755 --- a/convert_tf_checkpoint_to_pytorch.py +++ b/convert_tf_checkpoint_to_pytorch.py @@ -68,11 +68,17 @@ def convert(): arrays.append(array) for name, array in zip(names, arrays): - name = name[5:] # skip "bert/" + if not name.startswith("bert"): + print("Skipping {}".format(name)) + continue + else: + name = name.replace("bert/", "") # skip "bert/" print("Loading {}".format(name)) name = name.split('/') - if name[0] in ['redictions', 'eq_relationship']: - print("Skipping") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if name[0] in ['redictions', 'eq_relationship'] or name[-1] == "adam_v" or name[-1] == "adam_m": + print("Skipping {}".format("/".join(name))) continue pointer = model for m_name in name: diff --git a/modeling.py b/modeling.py index 9874b3d5df..3b3f198c92 100644 --- a/modeling.py +++ b/modeling.py @@ -26,6 +26,10 @@ import torch import torch.nn as nn from torch.nn import CrossEntropyLoss + +ACT2FN = {"gelu": gelu, "relu": torch.nn.ReLU, "swish": swish} + + def gelu(x): """Implementation of the gelu activation function. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): @@ -241,8 +245,8 @@ class BERTIntermediate(nn.Module): def __init__(self, config): super(BERTIntermediate, self).__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) - act2fn = {"gelu": gelu, "relu": torch.nn.ReLU, "swish": swish} - self.intermediate_act_fn = act2fn[config.hidden_act] if isinstance(config.hidden_act, str) else config.hidden_act + self.intermediate_act_fn = ACT2FN[config.hidden_act] \ + if isinstance(config.hidden_act, str) else config.hidden_act def forward(self, hidden_states): hidden_states = self.dense(hidden_states)