model conversion WIP
This commit is contained in:
@@ -105,7 +105,132 @@ class BertConfig(object):
|
|||||||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
||||||
|
|
||||||
|
|
||||||
|
class BERTLayerNorm(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
tf.contrib.layers.layer_norm(
|
||||||
|
inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)
|
||||||
|
|
||||||
class BERTEmbeddings(nn.Module):
|
class BERTEmbeddings(nn.Module):
|
||||||
|
def __init__(self, embedding_size, vocab_size,
|
||||||
|
token_type_vocab_size, max_position_embeddings,
|
||||||
|
config):
|
||||||
|
self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size)
|
||||||
|
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
|
||||||
|
self.token_type_embeddings = nn.Embedding(config.token_type_vocab_size, config.embedding_size)
|
||||||
|
|
||||||
|
self.LayerNorm = BERTLayerNorm() # Not snake-cased to fit with TF model variable name
|
||||||
|
self.dropout = nn.dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
self.initialize_weights(self, config.initializer_range)
|
||||||
|
|
||||||
|
def initialize_weights(self, initializer_range):
|
||||||
|
torch.truncated_normal_initializer(stddev=initializer_range)
|
||||||
|
|
||||||
|
def forward(self, input_ids, token_type_ids=None):
|
||||||
|
batch_size = input_ids.size(0)
|
||||||
|
seq_length = input_ids.size(1)
|
||||||
|
position_ids = torch.range().view(batch_size, seq_length)
|
||||||
|
|
||||||
|
words_embeddings = self.word_embeddings(input_ids)
|
||||||
|
position_embeddings = self.position_embeddings(position_ids)
|
||||||
|
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||||
|
|
||||||
|
embeddings = words_embeddings + position_embeddings + token_type_embeddings
|
||||||
|
embeddings = self.LayerNorm(embeddings)
|
||||||
|
embeddings = self.dropout(embeddings)
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
class BERTIntermediate(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super(BERTOutput, self).__init__()
|
||||||
|
self.dense = nn.Linear()
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BERTOutput(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super(BERTOutput, self).__init__()
|
||||||
|
self.dense = nn.Linear()
|
||||||
|
self.LayerNorm = BERTLayerNorm(config)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.LayerNorm(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BERTSelfAttention(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super(BERTSelfAttention, self).__init__()
|
||||||
|
if config.hidden_size % config.num_attention_heads != 0:
|
||||||
|
raise ValueError(
|
||||||
|
"The hidden size (%d) is not a multiple of the number of attention "
|
||||||
|
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
|
||||||
|
attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||||
|
all_head_size = num_attention_heads * attention_head_size
|
||||||
|
|
||||||
|
self.query = nn.Linear(config.hidden_size, all_head_size)
|
||||||
|
self.key = nn.Linear(config.hidden_size, all_head_size)
|
||||||
|
self.value = nn.Linear(config.hidden_size, all_head_size)
|
||||||
|
|
||||||
|
def transpose_for_scores(self, x, k=False):
|
||||||
|
new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
|
||||||
|
x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
|
||||||
|
if k:
|
||||||
|
return x.permute(0, 2, 3, 1)
|
||||||
|
else:
|
||||||
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BERTAttention(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super(BERTAttention, self).__init__()
|
||||||
|
self.self = BERTSelfAttention(config)
|
||||||
|
self.output = BERTOutput(config)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = self.self(hidden_states)
|
||||||
|
hidden_states = self.output(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BERTLayer(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super(BERTLayer, self).__init__()
|
||||||
|
self.attention = BERTAttention(config)
|
||||||
|
self.intermediate = BERTIntermediate(config)
|
||||||
|
self.output = BERTOutput(config)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = self.attention(hidden_states)
|
||||||
|
hidden_states = self.intermediate(hidden_states)
|
||||||
|
hidden_states = self.output(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BERTEncoder(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super(BERTEncoder, self).__init__()
|
||||||
|
layer = BERTLayer(n_ctx, cfg, scale=True)
|
||||||
|
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hidden_states: float Tensor of shape [batch_size, seq_length, hidden_size]
|
||||||
|
Return:
|
||||||
|
float Tensor of shape [batch_size, seq_length, hidden_size]
|
||||||
|
"""
|
||||||
|
for layer_module in self.layer:
|
||||||
|
hidden_states = layer_module(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class BertModel(nn.Module):
|
class BertModel(nn.Module):
|
||||||
@@ -132,28 +257,11 @@ class BertModel(nn.Module):
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self, config: BertConfig):
|
||||||
config,
|
|
||||||
is_training,
|
|
||||||
input_ids,
|
|
||||||
input_mask=None,
|
|
||||||
token_type_ids=None,
|
|
||||||
use_one_hot_embeddings=True,
|
|
||||||
scope=None):
|
|
||||||
"""Constructor for BertModel.
|
"""Constructor for BertModel.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: `BertConfig` instance.
|
config: `BertConfig` instance.
|
||||||
is_training: bool. rue for training model, false for eval model. Controls
|
|
||||||
whether dropout will be applied.
|
|
||||||
input_ids: int32 Tensor of shape [batch_size, seq_length].
|
|
||||||
input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
|
|
||||||
token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
|
|
||||||
use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
|
|
||||||
embeddings or tf.embedding_lookup() for the word embeddings. On the TPU,
|
|
||||||
it is must faster if this is True, on the CPU or GPU, it is faster if
|
|
||||||
this is False.
|
|
||||||
scope: (optional) variable scope. Defaults to "bert".
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: The config is invalid or one of the input tensor shapes
|
ValueError: The config is invalid or one of the input tensor shapes
|
||||||
@@ -168,15 +276,20 @@ class BertModel(nn.Module):
|
|||||||
batch_size = input_ids.size(0)
|
batch_size = input_ids.size(0)
|
||||||
seq_length = input_ids.size(1)
|
seq_length = input_ids.size(1)
|
||||||
|
|
||||||
|
self.embeddings = BERTEmbeddings(config)
|
||||||
|
self.encoder = BERTEncoder(config)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, input_ids, token_type_ids=None, input_mask=None):
|
||||||
if input_mask is None:
|
if input_mask is None:
|
||||||
input_mask = torch.ones(batch_size, seq_length), dtype=torch.long)
|
input_mask = torch.ones(batch_size, seq_length), dtype=torch.long)
|
||||||
|
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
token_type_ids = torch.zeros((batch_size, seq_length), dtype=torch.long)
|
token_type_ids = torch.zeros((batch_size, seq_length), dtype=torch.long)
|
||||||
|
|
||||||
self.embeddings = BERTEmbeddings(config.vocab_size, config.hidden_size)
|
hidden_states = self.embeddings(input_ids, token_type_ids, input_mask)
|
||||||
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
|
hidden_states = self.encoder(hidden_states)
|
||||||
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
|
|
||||||
# Perform embedding lookup on the word ids.
|
# Perform embedding lookup on the word ids.
|
||||||
(self.embedding_output, self.embedding_table) = embedding_lookup(
|
(self.embedding_output, self.embedding_table) = embedding_lookup(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
|
|||||||
Reference in New Issue
Block a user