convertion script WIP

This commit is contained in:
thomwolf
2018-11-01 18:00:20 +01:00
parent 5581edb4f6
commit ab0e8932a8
2 changed files with 25 additions and 20 deletions

View File

@@ -129,8 +129,8 @@ class BERTLayerNorm(nn.Module):
class BERTEmbeddings(nn.Module):
def __init__(self, config):
self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size)
super(BERTEmbeddings, self).__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
# Position embeddings are (normally) a contiguous range so we could use a slice
# Since the position embedding table is a learned variable, we create it
@@ -142,12 +142,12 @@ class BERTEmbeddings(nn.Module):
# for position [0, 1, 2, ..., max_position_embeddings-1], and the current
# sequence has positions [0, 1, 2, ... seq_length-1], so we can just
# perform a slice.
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
# token_type_embeddings vocabulary is very small. TF used one-hot embeddings to speedup.
self.token_type_embeddings = nn.Embedding(config.token_type_vocab_size, config.embedding_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
self.LayerNorm = BERTLayerNorm() # Not snake-cased to stick with TF model variable name
self.LayerNorm = BERTLayerNorm(config) # Not snake-cased to stick with TF model variable name
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids, token_type_ids=None):
@@ -185,7 +185,7 @@ class BERTSelfAttention(nn.Module):
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, input_tensor, num_attention_heads, is_key_tensor=False):
def transpose_for_scores(self, x, is_key_tensor=False):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
if is_key_tensor:
@@ -270,7 +270,7 @@ class BERTAttention(nn.Module):
class BERTIntermediate(nn.Module):
def __init__(self, config):
super(BERTOutput, self).__init__()
super(BERTIntermediate, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
self.intermediate_act_fn = gelu
@@ -305,13 +305,13 @@ class BERTLayer(nn.Module):
attention_output = self.attention(hidden_states, attention_mask)
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return hidden_states
return layer_output
class BERTEncoder(nn.Module):
def __init__(self, config):
super(BERTEncoder, self).__init__()
layer = BERTLayer(n_ctx, cfg, scale=True)
layer = BERTLayer(config)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask):
@@ -383,7 +383,7 @@ class BertModel(nn.Module):
ValueError: The config is invalid or one of the input tensor shapes
is invalid.
"""
super(BertModel).__init__()
super(BertModel, self).__init__()
self.embeddings = BERTEmbeddings(config)
self.encoder = BERTEncoder(config)
self.pooler = BERTPooler(config)