diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py index 77e9cda349..be33eda1c6 100644 --- a/pytorch_pretrained_bert/modeling_openai.py +++ b/pytorch_pretrained_bert/modeling_openai.py @@ -143,6 +143,7 @@ class OpenAIGPTConfig(object): attn_pdrop=0.1, layer_norm_epsilon=1e-5, initializer_range=0.02, + predict_special_tokens=True ): """Constructs OpenAIGPTConfig. @@ -165,6 +166,7 @@ class OpenAIGPTConfig(object): layer_norm_epsilon: epsilon to use in the layer norm layers initializer_range: The sttdev of the truncated_normal_initializer for initializing all weight matrices. + predict_special_tokens: should we predict special tokens (when the model has a LM head) """ if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 and isinstance(vocab_size_or_config_json_file, unicode)): @@ -186,6 +188,7 @@ class OpenAIGPTConfig(object): self.attn_pdrop = attn_pdrop self.layer_norm_epsilon = layer_norm_epsilon self.initializer_range = initializer_range + self.predict_special_tokens = predict_special_tokens else: raise ValueError( "First argument must be either a vocabulary size (int)" @@ -356,18 +359,21 @@ class OpenAIGPTLMHead(nn.Module): def __init__(self, model_embeddings_weights, config): super(OpenAIGPTLMHead, self).__init__() self.n_embd = config.n_embd + self.vocab_size = config.vocab_size + self.predict_special_tokens = config.predict_special_tokens embed_shape = model_embeddings_weights.shape self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False) self.set_embeddings_weights(model_embeddings_weights) - def set_embeddings_weights(self, model_embeddings_weights): + def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True): + self.predict_special_tokens = predict_special_tokens embed_shape = model_embeddings_weights.shape self.decoder.weight = model_embeddings_weights # Tied weights def forward(self, hidden_state): - # Truncated Language modeling logits (we remove the last token) - # h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd) lm_logits = self.decoder(hidden_state) + if not self.predict_special_tokens: + lm_logits = lm_logits[..., :self.vocab_size] return lm_logits @@ -428,9 +434,6 @@ class OpenAIGPTPreTrainedModel(nn.Module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def set_num_special_tokens(self, num_special_tokens): - pass - @classmethod def from_pretrained( cls, pretrained_model_name_or_path, num_special_tokens=None, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs @@ -613,7 +616,6 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)]) self.apply(self.init_weights) - # nn.init.normal_(self.embed.weight, std=0.02) def set_num_special_tokens(self, num_special_tokens): " Update input embeddings with new embedding matrice if needed " @@ -727,12 +729,13 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config) self.apply(self.init_weights) - def set_num_special_tokens(self, num_special_tokens): + def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True): """ Update input and output embeddings with new embedding matrice Make sure we are sharing the embeddings """ + self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens self.transformer.set_num_special_tokens(num_special_tokens) - self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight) + self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens) def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None): hidden_states = self.transformer(input_ids, position_ids, token_type_ids) @@ -821,12 +824,13 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config) self.apply(self.init_weights) - def set_num_special_tokens(self, num_special_tokens): + def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True): """ Update input and output embeddings with new embedding matrice Make sure we are sharing the embeddings """ + self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens self.transformer.set_num_special_tokens(num_special_tokens) - self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight) + self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens) def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None): hidden_states = self.transformer(input_ids, position_ids, token_type_ids)