def _intermediate_size(config):
    """Return the feed-forward width, falling back to ``hidden_size``.

    NOTE(review): ``BertConfig`` is defined outside this view; it is assumed
    to expose ``intermediate_size`` -- confirm against the config class.
    """
    return getattr(config, "intermediate_size", config.hidden_size)


class BERTLayerNorm(nn.Module):
    """Layer normalization over the last dimension of the input.

    Args:
        config: optional config object providing ``hidden_size``. When given,
            learnable ``gamma`` (scale) and ``beta`` (shift) vectors of that
            size are created. When omitted, the module only normalizes.
        variance_epsilon: small constant added to the variance for stability.
    """

    def __init__(self, config=None, variance_epsilon=1e-12):
        super(BERTLayerNorm, self).__init__()
        if config is None:
            self.register_parameter("gamma", None)
            self.register_parameter("beta", None)
        else:
            self.gamma = nn.Parameter(torch.ones(config.hidden_size))
            self.beta = nn.Parameter(torch.zeros(config.hidden_size))
        self.variance_epsilon = variance_epsilon

    def forward(self, x):
        """Normalize ``x`` to zero mean / unit variance along its last axis."""
        mean = x.mean(-1, keepdim=True)
        centered = x - mean
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        if self.gamma is None:
            return normalized
        return self.gamma * normalized + self.beta


class BERTEmbeddings(nn.Module):
    """Sum of word, position and token-type embeddings, then norm + dropout.

    Args:
        config: config object providing ``vocab_size``, ``hidden_size``,
            ``max_position_embeddings``, ``hidden_dropout_prob``,
            ``initializer_range`` and a token-type vocabulary size.
    """

    def __init__(self, config):
        super(BERTEmbeddings, self).__init__()
        # NOTE(review): BertConfig is outside this view. The original code
        # read `token_type_vocab_size`; `type_vocab_size` is accepted as a
        # fallback spelling -- confirm which one the config really defines.
        type_vocab_size = getattr(config, "token_type_vocab_size", None)
        if type_vocab_size is None:
            type_vocab_size = config.type_vocab_size

        # The embedding width must equal hidden_size because the output is
        # fed straight into the encoder layers.
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(
            type_vocab_size, config.hidden_size)

        # Not snake-cased to fit with TF model variable name
        self.LayerNorm = BERTLayerNorm(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.initialize_weights(config.initializer_range)

    def initialize_weights(self, initializer_range):
        """Re-initialize all embedding tables from a truncated normal.

        Values are drawn from N(0, initializer_range) and truncated at two
        standard deviations on each side.
        """
        bound = 2.0 * initializer_range
        tables = (
            self.word_embeddings,
            self.position_embeddings,
            self.token_type_embeddings,
        )
        for table in tables:
            nn.init.trunc_normal_(
                table.weight, mean=0.0, std=initializer_range, a=-bound, b=bound)

    def forward(self, input_ids, token_type_ids=None):
        """Embed ``input_ids``.

        Args:
            input_ids: integer tensor of shape [batch_size, seq_length].
            token_type_ids: optional integer tensor with the same shape as
                ``input_ids``; defaults to all zeros.

        Returns:
            Float tensor of shape [batch_size, seq_length, hidden_size].
        """
        seq_length = input_ids.size(1)
        position_ids = torch.arange(
            seq_length, dtype=torch.long, device=input_ids.device)
        # Same positions for every row of the batch.
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class BERTIntermediate(nn.Module):
    """Feed-forward expansion: hidden_size -> intermediate size."""

    def __init__(self, config):
        super(BERTIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, _intermediate_size(config))

    def forward(self, hidden_states):
        # NOTE(review): no activation is applied here, as in the original
        # block; confirm whether a non-linearity is still to be ported.
        hidden_states = self.dense(hidden_states)
        return hidden_states


class BERTOutput(nn.Module):
    """Linear projection back to hidden_size followed by layer norm.

    Args:
        config: config object providing ``hidden_size``.
        input_size: width of the incoming tensor; defaults to
            ``config.hidden_size`` (the attention-output case).
    """

    def __init__(self, config, input_size=None):
        super(BERTOutput, self).__init__()
        if input_size is None:
            input_size = config.hidden_size
        self.dense = nn.Linear(input_size, config.hidden_size)
        self.LayerNorm = BERTLayerNorm(config)

    def forward(self, hidden_states):
        # NOTE(review): no residual connection, as in the original block.
        hidden_states = self.dense(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BERTSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Raises:
        ValueError: if ``hidden_size`` is not a multiple of
            ``num_attention_heads``.
    """

    def __init__(self, config):
        super(BERTSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        # `n_head` is the attribute name read by transpose_for_scores.
        self.n_head = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        all_head_size = self.n_head * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, all_head_size)
        self.key = nn.Linear(config.hidden_size, all_head_size)
        self.value = nn.Linear(config.hidden_size, all_head_size)

    def transpose_for_scores(self, x, k=False):
        """Split the last axis into heads.

        Returns [batch, heads, seq, head_size], or, when ``k`` is true,
        [batch, heads, head_size, seq] so it can be right-multiplied.
        """
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None):
        """Apply self-attention.

        Args:
            hidden_states: float tensor [batch_size, seq_length, hidden_size].
            attention_mask: optional additive mask broadcastable to
                [batch_size, heads, seq_length, seq_length]; large negative
                values remove the corresponding key positions.

        Returns:
            Float tensor [batch_size, seq_length, hidden_size].
        """
        query = self.transpose_for_scores(self.query(hidden_states))
        key = self.transpose_for_scores(self.key(hidden_states), k=True)
        value = self.transpose_for_scores(self.value(hidden_states))

        scores = torch.matmul(query, key) / (self.attention_head_size ** 0.5)
        if attention_mask is not None:
            scores = scores + attention_mask
        probs = torch.softmax(scores, dim=-1)

        context = torch.matmul(probs, value)
        # [batch, heads, seq, head_size] -> [batch, seq, heads * head_size]
        context = context.permute(0, 2, 1, 3).contiguous()
        merged_shape = context.size()[:-2] + (self.n_head * self.attention_head_size,)
        return context.view(*merged_shape)


class BERTAttention(nn.Module):
    """Self-attention followed by its output projection."""

    def __init__(self, config):
        super(BERTAttention, self).__init__()
        self.self = BERTSelfAttention(config)
        self.output = BERTOutput(config)

    def forward(self, hidden_states, attention_mask=None):
        hidden_states = self.self(hidden_states, attention_mask)
        hidden_states = self.output(hidden_states)
        return hidden_states


class BERTLayer(nn.Module):
    """One encoder layer: attention, intermediate projection, output."""

    def __init__(self, config):
        super(BERTLayer, self).__init__()
        self.attention = BERTAttention(config)
        self.intermediate = BERTIntermediate(config)
        # Consumes the intermediate width, not hidden_size.
        self.output = BERTOutput(config, input_size=_intermediate_size(config))

    def forward(self, hidden_states, attention_mask=None):
        hidden_states = self.attention(hidden_states, attention_mask)
        hidden_states = self.intermediate(hidden_states)
        hidden_states = self.output(hidden_states)
        return hidden_states


class BERTEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` independent ``BERTLayer``s."""

    def __init__(self, config):
        super(BERTEncoder, self).__init__()
        # Build each layer separately so they do not share initial weights.
        self.layer = nn.ModuleList(
            [BERTLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask=None):
        """
        Args:
            hidden_states: float Tensor of shape [batch_size, seq_length, hidden_size]
            attention_mask: optional additive mask passed to every layer
        Return:
            float Tensor of shape [batch_size, seq_length, hidden_size]
        """
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask)
        return hidden_states


class BertModel(nn.Module):
    """Embeddings followed by the encoder stack.

    NOTE(review): the original class docstring and the legacy TensorFlow body
    that followed ``forward`` lie outside the visible diff hunks and are not
    reproduced here.
    """

    def __init__(self, config: "BertConfig"):
        """Constructor for BertModel.

        Args:
            config: `BertConfig` instance.

        Raises:
            ValueError: The config is invalid (hidden size not a multiple of
                the number of attention heads).
        """
        super(BertModel, self).__init__()
        self.embeddings = BERTEmbeddings(config)
        self.encoder = BERTEncoder(config)

    def forward(self, input_ids, token_type_ids=None, input_mask=None):
        """Encode a batch of token ids.

        Args:
            input_ids: integer tensor of shape [batch_size, seq_length].
            token_type_ids: optional integer tensor, same shape; defaults to
                all zeros.
            input_mask: optional tensor, same shape, 1 for real tokens and 0
                for padding; defaults to all ones.

        Returns:
            Float tensor of shape [batch_size, seq_length, hidden_size].
        """
        if input_mask is None:
            input_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        hidden_states = self.embeddings(input_ids, token_type_ids)

        # Turn the 0/1 mask into an additive one that broadcasts over heads
        # and query positions: 0 keeps a key, -10000 suppresses it.
        attention_mask = input_mask.unsqueeze(1).unsqueeze(2)
        attention_mask = attention_mask.to(dtype=hidden_states.dtype)
        attention_mask = (1.0 - attention_mask) * -10000.0

        hidden_states = self.encoder(hidden_states, attention_mask)
        return hidden_states