model conversion WIP
This commit is contained in:
@@ -105,7 +105,132 @@ class BertConfig(object):
|
||||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
||||
|
||||
|
||||
class BERTLayerNorm(nn.Module):
    """Layer normalisation over the last dimension of the input.

    PyTorch replacement for the ``tf.contrib.layers.layer_norm(...,
    begin_norm_axis=-1, begin_params_axis=-1)`` call this class was drafted
    from: statistics are computed over the last axis only, and the optional
    scale/shift parameters span that same axis.
    """

    def __init__(self, config=None, variance_epsilon=1e-12):
        """Build the normalisation layer.

        Args:
            config: optional configuration object. When it exposes a
                ``hidden_size`` attribute, a learnable scale (``gamma``) and
                shift (``beta``) of that width are created. Without it the
                module only standardises its input and owns no parameters,
                which keeps the zero-argument ``BERTLayerNorm()`` form usable.
            variance_epsilon: constant added to the variance before taking
                the square root, so a constant input does not divide by zero.
                NOTE(review): 1e-12 is believed to match the TensorFlow op;
                confirm against the reference checkpoint.
        """
        super(BERTLayerNorm, self).__init__()
        self.variance_epsilon = variance_epsilon
        hidden_size = getattr(config, "hidden_size", None)
        if hidden_size is None:
            # Registered as None so ``self.gamma`` / ``self.beta`` always exist.
            self.register_parameter("gamma", None)
            self.register_parameter("beta", None)
        else:
            # Named gamma/beta (not weight/bias) to line up with the
            # TensorFlow variable names when converting checkpoints.
            self.gamma = nn.Parameter(torch.ones(hidden_size))
            self.beta = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x):
        """Normalise ``x`` along its last dimension.

        Args:
            x: floating-point tensor; the last dimension is normalised.

        Returns:
            Tensor of the same shape as ``x``, scaled and shifted by
            ``gamma`` / ``beta`` when those parameters exist.
        """
        mean = x.mean(-1, keepdim=True)
        # Biased (population) variance, as layer normalisation uses.
        variance = (x - mean).pow(2).mean(-1, keepdim=True)
        x = (x - mean) / torch.sqrt(variance + self.variance_epsilon)
        if self.gamma is None:
            return x
        return self.gamma * x + self.beta
|
||||
|
||||
class BERTEmbeddings(nn.Module):
    """Sum of word, position and token-type embeddings, then norm and dropout."""

    def __init__(self, embedding_size=None, vocab_size=None,
                 token_type_vocab_size=None, max_position_embeddings=None,
                 config=None):
        """Create the three embedding tables.

        The module can be built either as ``BERTEmbeddings(config)`` or with
        the original positional layout followed by ``config``. Any size given
        explicitly takes precedence over the matching ``config`` attribute.

        Args:
            embedding_size: width of every embedding vector, or the config
                object itself when it is the only argument supplied.
            vocab_size: number of rows in the word embedding table.
            token_type_vocab_size: number of rows in the token-type table.
            max_position_embeddings: number of rows in the position table.
            config: configuration object providing ``hidden_dropout_prob``,
                ``initializer_range`` and any size not passed explicitly.

        Raises:
            TypeError: if no configuration object is supplied.
        """
        super(BERTEmbeddings, self).__init__()

        # Support the single-argument form: the config arrives in the first slot.
        if (config is None and embedding_size is not None
                and not isinstance(embedding_size, int)):
            config, embedding_size = embedding_size, None
        if config is None:
            raise TypeError("BERTEmbeddings requires a config object")

        def resolve(explicit, attribute):
            # An explicit argument wins; otherwise fall back to the config.
            return getattr(config, attribute) if explicit is None else explicit

        embedding_size = resolve(embedding_size, "embedding_size")
        vocab_size = resolve(vocab_size, "vocab_size")
        token_type_vocab_size = resolve(token_type_vocab_size,
                                        "token_type_vocab_size")
        max_position_embeddings = resolve(max_position_embeddings,
                                          "max_position_embeddings")

        self.word_embeddings = nn.Embedding(vocab_size, embedding_size)
        self.position_embeddings = nn.Embedding(max_position_embeddings,
                                                embedding_size)
        self.token_type_embeddings = nn.Embedding(token_type_vocab_size,
                                                  embedding_size)

        self.LayerNorm = BERTLayerNorm()  # Not snake-cased to fit with TF model variable name
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.initialize_weights(config.initializer_range)

    def initialize_weights(self, initializer_range):
        """Re-draw all three embedding tables from a truncated normal.

        Args:
            initializer_range: standard deviation of the distribution; samples
                are confined to two standard deviations around zero.
        """
        bound = 2.0 * initializer_range
        for table in (self.word_embeddings, self.position_embeddings,
                      self.token_type_embeddings):
            nn.init.trunc_normal_(table.weight, mean=0.0,
                                  std=initializer_range, a=-bound, b=bound)

    def forward(self, input_ids, token_type_ids=None):
        """Embed a batch of token ids.

        Args:
            input_ids: integer tensor of shape [batch_size, seq_length].
            token_type_ids: optional integer tensor shaped like ``input_ids``;
                when omitted every token is assigned type 0.

        Returns:
            Float tensor of shape [batch_size, seq_length, embedding_size].
        """
        seq_length = input_ids.size(1)
        # Positions 0..seq_length-1, repeated for every sequence in the batch.
        position_ids = torch.arange(seq_length, dtype=torch.long,
                                    device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
|
||||
|
||||
|
||||
class BERTIntermediate(nn.Module):
    """Single dense projection applied between attention and layer output."""

    def __init__(self, config, intermediate_size=None):
        """Build the projection.

        Args:
            config: configuration object exposing ``hidden_size``.
            intermediate_size: optional output width of the projection.
                Defaults to ``config.hidden_size`` so the result can be fed
                to a module expecting hidden-sized input.
                NOTE(review): the reference model may widen the
                representation here; pass the width explicitly once the
                consuming layer accepts it.
        """
        # The original named BERTOutput here, which raises TypeError because
        # ``self`` is not an instance of that class.
        super(BERTIntermediate, self).__init__()
        if intermediate_size is None:
            intermediate_size = config.hidden_size
        self.dense = nn.Linear(config.hidden_size, intermediate_size)

    def forward(self, hidden_states):
        """Project ``hidden_states`` along its last dimension.

        Args:
            hidden_states: float tensor whose last dimension is
                ``config.hidden_size``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``intermediate_size``.
        """
        # NOTE(review): no non-linearity is applied yet; confirm whether an
        # activation belongs after this projection.
        hidden_states = self.dense(hidden_states)
        return hidden_states
|
||||
|
||||
|
||||
class BERTOutput(nn.Module):
    """Dense projection back to the hidden width followed by layer norm."""

    def __init__(self, config, input_size=None):
        """Build the projection and its normalisation layer.

        Args:
            config: configuration object exposing ``hidden_size``; it is also
                handed to ``BERTLayerNorm``.
            input_size: optional width of the incoming tensor's last
                dimension. Defaults to ``config.hidden_size``.
        """
        super(BERTOutput, self).__init__()
        if input_size is None:
            input_size = config.hidden_size
        # The original ``nn.Linear()`` had no dimensions and could not be built.
        self.dense = nn.Linear(input_size, config.hidden_size)
        self.LayerNorm = BERTLayerNorm(config)

    def forward(self, hidden_states):
        """Project then normalise ``hidden_states``.

        Args:
            hidden_states: float tensor whose last dimension is ``input_size``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``config.hidden_size``.
        """
        hidden_states = self.dense(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states
|
||||
|
||||
|
||||
class BERTSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention."""

    def __init__(self, config):
        """Create the query, key and value projections.

        Args:
            config: configuration object exposing ``hidden_size`` and
                ``num_attention_heads``.

        Raises:
            ValueError: if ``hidden_size`` is not divisible by
                ``num_attention_heads``.
        """
        super(BERTSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        # Stored on the instance: transpose_for_scores and forward need them.
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

    def transpose_for_scores(self, x, k=False):
        """Split the last dimension into heads and move the head axis forward.

        Args:
            x: tensor of shape [batch_size, seq_length, all_head_size].
            k: when True the result is laid out for use as the right-hand
                operand of the score matmul (keys transposed).

        Returns:
            Tensor of shape [batch_size, heads, seq_length, head_size], or
            [batch_size, heads, head_size, seq_length] when ``k`` is True.
        """
        heads = self.num_attention_heads
        new_x_shape = x.size()[:-1] + (heads, x.size(-1) // heads)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None):
        """Attend every position of ``hidden_states`` to every other position.

        Args:
            hidden_states: float tensor of shape
                [batch_size, seq_length, hidden_size].
            attention_mask: optional tensor added to the raw attention scores
                before the softmax; it must broadcast against
                [batch_size, heads, seq_length, seq_length]. Large negative
                entries suppress the corresponding key positions.

        Returns:
            Float tensor of shape [batch_size, seq_length, all_head_size].
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states), k=True)
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Scale by sqrt(head_size) so score magnitude does not grow with width.
        scores = torch.matmul(query_layer, key_layer)
        scores = scores / (self.attention_head_size ** 0.5)
        if attention_mask is not None:
            scores = scores + attention_mask
        probs = nn.functional.softmax(scores, dim=-1)

        context = torch.matmul(probs, value_layer)
        # Back to [batch, seq, heads, head_size], then merge the head axes.
        context = context.permute(0, 2, 1, 3).contiguous()
        merged_shape = context.size()[:-2] + (self.all_head_size,)
        return context.view(*merged_shape)
|
||||
|
||||
|
||||
class BERTAttention(nn.Module):
    """Self-attention followed by its output projection."""

    def __init__(self, config):
        """Instantiate the two sub-modules from the shared ``config``."""
        super().__init__()
        self.self = BERTSelfAttention(config)
        self.output = BERTOutput(config)

    def forward(self, hidden_states):
        """Run ``hidden_states`` through ``self.self`` and then ``self.output``."""
        attended = self.self(hidden_states)
        return self.output(attended)
|
||||
|
||||
|
||||
class BERTLayer(nn.Module):
    """One encoder layer: attention, intermediate projection, output block."""

    def __init__(self, config):
        """Instantiate the three stages from the shared ``config``."""
        super().__init__()
        self.attention = BERTAttention(config)
        self.intermediate = BERTIntermediate(config)
        self.output = BERTOutput(config)

    def forward(self, hidden_states):
        """Feed ``hidden_states`` through the three stages in order."""
        attended = self.attention(hidden_states)
        projected = self.intermediate(attended)
        return self.output(projected)
|
||||
|
||||
|
||||
class BERTEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` identical ``BERTLayer`` modules."""

    def __init__(self, config):
        """Build the layer stack.

        Args:
            config: configuration object exposing ``num_hidden_layers``; it is
                also handed to ``BERTLayer``.
        """
        super(BERTEncoder, self).__init__()
        # The original passed undefined names (n_ctx, cfg) and a ``scale``
        # keyword that BERTLayer does not accept.
        layer = BERTLayer(config)
        # Each copy owns its parameters but starts from the same initial values.
        self.layer = nn.ModuleList(
            [copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states):
        """Apply every layer in sequence.

        Args:
            hidden_states: float Tensor of shape [batch_size, seq_length, hidden_size]
        Return:
            float Tensor of shape [batch_size, seq_length, hidden_size]
        """
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states)
        return hidden_states
|
||||
|
||||
|
||||
class BertModel(nn.Module):
|
||||
@@ -132,28 +257,11 @@ class BertModel(nn.Module):
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
is_training,
|
||||
input_ids,
|
||||
input_mask=None,
|
||||
token_type_ids=None,
|
||||
use_one_hot_embeddings=True,
|
||||
scope=None):
|
||||
def __init__(self, config: BertConfig):
|
||||
"""Constructor for BertModel.
|
||||
|
||||
Args:
|
||||
config: `BertConfig` instance.
|
||||
is_training: bool. True for training model, false for eval model. Controls
|
||||
whether dropout will be applied.
|
||||
input_ids: int32 Tensor of shape [batch_size, seq_length].
|
||||
input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
|
||||
token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
|
||||
use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
|
||||
embeddings or tf.embedding_lookup() for the word embeddings. On the TPU,
|
||||
it is much faster if this is True, on the CPU or GPU, it is faster if
|
||||
this is False.
|
||||
scope: (optional) variable scope. Defaults to "bert".
|
||||
|
||||
Raises:
|
||||
ValueError: The config is invalid or one of the input tensor shapes
|
||||
@@ -168,15 +276,20 @@ class BertModel(nn.Module):
|
||||
batch_size = input_ids.size(0)
|
||||
seq_length = input_ids.size(1)
|
||||
|
||||
if input_mask is None:
|
||||
input_mask = torch.ones(batch_size, seq_length), dtype=torch.long)
|
||||
self.embeddings = BERTEmbeddings(config)
|
||||
self.encoder = BERTEncoder(config)
|
||||
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros((batch_size, seq_length), dtype=torch.long)
|
||||
|
||||
self.embeddings = BERTEmbeddings(config.vocab_size, config.hidden_size)
|
||||
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
def forward(self, input_ids, token_type_ids=None, input_mask=None):
|
||||
if input_mask is None:
|
||||
input_mask = torch.ones(batch_size, seq_length), dtype=torch.long)
|
||||
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros((batch_size, seq_length), dtype=torch.long)
|
||||
|
||||
hidden_states = self.embeddings(input_ids, token_type_ids, input_mask)
|
||||
hidden_states = self.encoder(hidden_states)
|
||||
|
||||
# Perform embedding lookup on the word ids.
|
||||
(self.embedding_output, self.embedding_table) = embedding_lookup(
|
||||
input_ids=input_ids,
|
||||
|
||||
Reference in New Issue
Block a user