convertion script WIP
This commit is contained in:
@@ -10,7 +10,7 @@ import argparse
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .modeling_pytorch import BertConfig, BertModel
|
from modeling_pytorch import BertConfig, BertModel
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
@@ -35,6 +35,10 @@ parser.add_argument("--pytorch_dump_path",
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
def convert():
|
def convert():
|
||||||
|
# Initialise PyTorch model
|
||||||
|
config = BertConfig.from_json_file(args.bert_config_file)
|
||||||
|
model = BertModel(config)
|
||||||
|
|
||||||
# Load weights from TF model
|
# Load weights from TF model
|
||||||
path = args.tf_checkpoint_path
|
path = args.tf_checkpoint_path
|
||||||
print("Converting TensorFlow checkpoint from {}".format(path))
|
print("Converting TensorFlow checkpoint from {}".format(path))
|
||||||
@@ -49,24 +53,26 @@ def convert():
|
|||||||
names.append(name)
|
names.append(name)
|
||||||
arrays.append(array)
|
arrays.append(array)
|
||||||
|
|
||||||
# Initialise PyTorch model and fill weights-in
|
|
||||||
config = BertConfig.from_json_file(args.bert_config_file)
|
|
||||||
model = BertModel(config)
|
|
||||||
for name, array in zip(names, arrays):
|
for name, array in zip(names, arrays):
|
||||||
name = name[5:] # skip "bert/"
|
name = name[5:] # skip "bert/"
|
||||||
assert name[-2:] == ":0"
|
|
||||||
name = name[:-2]
|
|
||||||
name = name.split('/')
|
name = name.split('/')
|
||||||
pointer = model
|
pointer = model
|
||||||
for m_name in name:
|
for m_name in name:
|
||||||
if re.fullmatch(r'[A-Za-z]+\d+', m_name):
|
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
|
||||||
l = re.split(r'(\d+)', m_name)
|
l = re.split(r'_(\d+)', m_name)
|
||||||
else:
|
else:
|
||||||
l = [m_name]
|
l = [m_name]
|
||||||
|
if l[0] == 'kernel':
|
||||||
|
pointer = getattr(pointer, 'weight')
|
||||||
|
else:
|
||||||
pointer = getattr(pointer, l[0])
|
pointer = getattr(pointer, l[0])
|
||||||
if len(l) >= 2:
|
if len(l) >= 2:
|
||||||
num = int(l[1])
|
num = int(l[1])
|
||||||
pointer = pointer[num]
|
pointer = pointer[num]
|
||||||
|
if m_name[-11:] == '_embeddings':
|
||||||
|
pointer = getattr(pointer, 'weight')
|
||||||
|
# elif m_name == 'kernel':
|
||||||
|
# pointer = getattr(pointer, 'weight')
|
||||||
try:
|
try:
|
||||||
assert pointer.shape == array.shape
|
assert pointer.shape == array.shape
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
@@ -79,4 +85,3 @@ def convert():
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
convert()
|
convert()
|
||||||
return None
|
|
||||||
|
|||||||
@@ -129,8 +129,8 @@ class BERTLayerNorm(nn.Module):
|
|||||||
|
|
||||||
class BERTEmbeddings(nn.Module):
|
class BERTEmbeddings(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
|
super(BERTEmbeddings, self).__init__()
|
||||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size)
|
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||||
|
|
||||||
# Position embeddings are (normally) a contiguous range so we could use a slice
|
# Position embeddings are (normally) a contiguous range so we could use a slice
|
||||||
# Since the position embedding table is a learned variable, we create it
|
# Since the position embedding table is a learned variable, we create it
|
||||||
@@ -142,12 +142,12 @@ class BERTEmbeddings(nn.Module):
|
|||||||
# for position [0, 1, 2, ..., max_position_embeddings-1], and the current
|
# for position [0, 1, 2, ..., max_position_embeddings-1], and the current
|
||||||
# sequence has positions [0, 1, 2, ... seq_length-1], so we can just
|
# sequence has positions [0, 1, 2, ... seq_length-1], so we can just
|
||||||
# perform a slice.
|
# perform a slice.
|
||||||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
|
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
||||||
|
|
||||||
# token_type_embeddings vocabulary is very small. TF used one-hot embeddings to speedup.
|
# token_type_embeddings vocabulary is very small. TF used one-hot embeddings to speedup.
|
||||||
self.token_type_embeddings = nn.Embedding(config.token_type_vocab_size, config.embedding_size)
|
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
||||||
|
|
||||||
self.LayerNorm = BERTLayerNorm() # Not snake-cased to stick with TF model variable name
|
self.LayerNorm = BERTLayerNorm(config) # Not snake-cased to stick with TF model variable name
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, input_ids, token_type_ids=None):
|
def forward(self, input_ids, token_type_ids=None):
|
||||||
@@ -185,7 +185,7 @@ class BERTSelfAttention(nn.Module):
|
|||||||
|
|
||||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||||
|
|
||||||
def transpose_for_scores(self, input_tensor, num_attention_heads, is_key_tensor=False):
|
def transpose_for_scores(self, x, is_key_tensor=False):
|
||||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(*new_x_shape)
|
||||||
if is_key_tensor:
|
if is_key_tensor:
|
||||||
@@ -270,7 +270,7 @@ class BERTAttention(nn.Module):
|
|||||||
|
|
||||||
class BERTIntermediate(nn.Module):
|
class BERTIntermediate(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(BERTOutput, self).__init__()
|
super(BERTIntermediate, self).__init__()
|
||||||
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||||
self.intermediate_act_fn = gelu
|
self.intermediate_act_fn = gelu
|
||||||
|
|
||||||
@@ -305,13 +305,13 @@ class BERTLayer(nn.Module):
|
|||||||
attention_output = self.attention(hidden_states, attention_mask)
|
attention_output = self.attention(hidden_states, attention_mask)
|
||||||
intermediate_output = self.intermediate(attention_output)
|
intermediate_output = self.intermediate(attention_output)
|
||||||
layer_output = self.output(intermediate_output, attention_output)
|
layer_output = self.output(intermediate_output, attention_output)
|
||||||
return hidden_states
|
return layer_output
|
||||||
|
|
||||||
|
|
||||||
class BERTEncoder(nn.Module):
|
class BERTEncoder(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(BERTEncoder, self).__init__()
|
super(BERTEncoder, self).__init__()
|
||||||
layer = BERTLayer(n_ctx, cfg, scale=True)
|
layer = BERTLayer(config)
|
||||||
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
|
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask):
|
def forward(self, hidden_states, attention_mask):
|
||||||
@@ -383,7 +383,7 @@ class BertModel(nn.Module):
|
|||||||
ValueError: The config is invalid or one of the input tensor shapes
|
ValueError: The config is invalid or one of the input tensor shapes
|
||||||
is invalid.
|
is invalid.
|
||||||
"""
|
"""
|
||||||
super(BertModel).__init__()
|
super(BertModel, self).__init__()
|
||||||
self.embeddings = BERTEmbeddings(config)
|
self.embeddings = BERTEmbeddings(config)
|
||||||
self.encoder = BERTEncoder(config)
|
self.encoder = BERTEncoder(config)
|
||||||
self.pooler = BERTPooler(config)
|
self.pooler = BERTPooler(config)
|
||||||
|
|||||||
Reference in New Issue
Block a user