diff --git a/convert_tf_checkpoint.py b/convert_tf_checkpoint.py index bd8ddd754f..b8fb49bf1f 100644 --- a/convert_tf_checkpoint.py +++ b/convert_tf_checkpoint.py @@ -10,7 +10,7 @@ import argparse import tensorflow as tf import torch -from .modeling_pytorch import BertConfig, BertModel +from modeling_pytorch import BertConfig, BertModel parser = argparse.ArgumentParser() @@ -35,6 +35,10 @@ parser.add_argument("--pytorch_dump_path", args = parser.parse_args() def convert(): + # Initialise PyTorch model + config = BertConfig.from_json_file(args.bert_config_file) + model = BertModel(config) + # Load weights from TF model path = args.tf_checkpoint_path print("Converting TensorFlow checkpoint from {}".format(path)) @@ -49,24 +53,26 @@ def convert(): names.append(name) arrays.append(array) - # Initialise PyTorch model and fill weights-in - config = BertConfig.from_json_file(args.bert_config_file) - model = BertModel(config) for name, array in zip(names, arrays): name = name[5:] # skip "bert/" - assert name[-2:] == ":0" - name = name[:-2] name = name.split('/') pointer = model for m_name in name: - if re.fullmatch(r'[A-Za-z]+\d+', m_name): - l = re.split(r'(\d+)', m_name) + if re.fullmatch(r'[A-Za-z]+_\d+', m_name): + l = re.split(r'_(\d+)', m_name) else: l = [m_name] - pointer = getattr(pointer, l[0]) + if l[0] == 'kernel': + pointer = getattr(pointer, 'weight') + else: + pointer = getattr(pointer, l[0]) if len(l) >= 2: num = int(l[1]) pointer = pointer[num] + if m_name[-11:] == '_embeddings': + pointer = getattr(pointer, 'weight') + # elif m_name == 'kernel': + # pointer = getattr(pointer, 'weight') try: assert pointer.shape == array.shape except AssertionError as e: @@ -79,4 +85,3 @@ def convert(): if __name__ == "__main__": convert() - return None diff --git a/modeling_pytorch.py b/modeling_pytorch.py index f3ae8ce77f..9f2b84e911 100644 --- a/modeling_pytorch.py +++ b/modeling_pytorch.py @@ -129,8 +129,8 @@ class BERTLayerNorm(nn.Module): class BERTEmbeddings(nn.Module): def __init__(self, config): - - self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size) + super(BERTEmbeddings, self).__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) # Position embeddings are (normally) a contiguous range so we could use a slice # Since the position embedding table is a learned variable, we create it @@ -142,12 +142,12 @@ class BERTEmbeddings(nn.Module): # for position [0, 1, 2, ..., max_position_embeddings-1], and the current # sequence has positions [0, 1, 2, ... seq_length-1], so we can just # perform a slice. - self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) # token_type_embeddings vocabulary is very small. TF used one-hot embeddings to speedup. - self.token_type_embeddings = nn.Embedding(config.token_type_vocab_size, config.embedding_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - self.LayerNorm = BERTLayerNorm() # Not snake-cased to stick with TF model variable name + self.LayerNorm = BERTLayerNorm(config) # Not snake-cased to stick with TF model variable name self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, input_ids, token_type_ids=None): @@ -185,7 +185,7 @@ class BERTSelfAttention(nn.Module): self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, input_tensor, num_attention_heads, is_key_tensor=False): + def transpose_for_scores(self, x, is_key_tensor=False): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(*new_x_shape) if is_key_tensor: @@ -270,7 +270,7 @@ class BERTAttention(nn.Module): class BERTIntermediate(nn.Module): def __init__(self, config): - super(BERTOutput, self).__init__() + super(BERTIntermediate, self).__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) self.intermediate_act_fn = gelu @@ -305,13 +305,13 @@ class BERTLayer(nn.Module): attention_output = self.attention(hidden_states, attention_mask) intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) - return hidden_states + return layer_output class BERTEncoder(nn.Module): def __init__(self, config): super(BERTEncoder, self).__init__() - layer = BERTLayer(n_ctx, cfg, scale=True) + layer = BERTLayer(config) self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) def forward(self, hidden_states, attention_mask): @@ -383,7 +383,7 @@ class BertModel(nn.Module): ValueError: The config is invalid or one of the input tensor shapes is invalid. """ - super(BertModel).__init__() + super(BertModel, self).__init__() self.embeddings = BERTEmbeddings(config) self.encoder = BERTEncoder(config) self.pooler = BERTPooler(config)