diff --git a/pytorch_pretrained_bert/modeling_gpt2.py b/pytorch_pretrained_bert/modeling_gpt2.py index 37c5a2d9fb..1c579de83c 100644 --- a/pytorch_pretrained_bert/modeling_gpt2.py +++ b/pytorch_pretrained_bert/modeling_gpt2.py @@ -115,6 +115,7 @@ class GPT2Config(object): n_head=12, layer_norm_epsilon=1e-5, initializer_range=0.02, + predict_special_tokens=True ): """Constructs GPT2Config. @@ -130,6 +131,7 @@ class GPT2Config(object): layer_norm_epsilon: epsilon to use in the layer norm layers initializer_range: The sttdev of the truncated_normal_initializer for initializing all weight matrices. + predict_special_tokens: should we predict special tokens (when the model has a LM head) """ if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 and isinstance(vocab_size_or_config_json_file, unicode)): @@ -147,6 +149,7 @@ class GPT2Config(object): self.n_head = n_head self.layer_norm_epsilon = layer_norm_epsilon self.initializer_range = initializer_range + self.predict_special_tokens = predict_special_tokens else: raise ValueError( "First argument must be either a vocabulary size (int)" @@ -297,18 +300,20 @@ class GPT2LMHead(nn.Module): def __init__(self, model_embeddings_weights, config): super(GPT2LMHead, self).__init__() self.n_embd = config.n_embd + self.vocab_size = config.vocab_size + self.predict_special_tokens = config.predict_special_tokens embed_shape = model_embeddings_weights.shape self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False) self.set_embeddings_weights(model_embeddings_weights) - def set_embeddings_weights(self, model_embeddings_weights): - embed_shape = model_embeddings_weights.shape + def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True): + self.predict_special_tokens = predict_special_tokens self.decoder.weight = model_embeddings_weights # Tied weights def forward(self, hidden_state): - # Truncated Language modeling logits (we remove the last token) - # h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd) lm_logits = self.decoder(hidden_state) + if not self.predict_special_tokens: + lm_logits = lm_logits[..., :self.vocab_size] return lm_logits @@ -353,9 +358,6 @@ class GPT2PreTrainedModel(nn.Module): ) self.config = config - def set_num_special_tokens(self, num_special_tokens): - pass - def init_weights(self, module): """ Initialize the weights. """ @@ -650,12 +652,13 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): self.lm_head = GPT2LMHead(self.transformer.wte.weight, config) self.apply(self.init_weights) - def set_num_special_tokens(self, num_special_tokens): + def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True): """ Update input and output embeddings with new embedding matrice Make sure we are sharing the embeddings """ + self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens self.transformer.set_num_special_tokens(num_special_tokens) - self.lm_head.set_embeddings_weights(self.transformer.wte.weight) + self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens) def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None): hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past) @@ -729,12 +732,13 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): self.multiple_choice_head = GPT2MultipleChoiceHead(config) self.apply(self.init_weights) - def set_num_special_tokens(self, num_special_tokens): + def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True): """ Update input and output embeddings with new embedding matrice Make sure we are sharing the embeddings """ + self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens self.transformer.set_num_special_tokens(num_special_tokens) - self.lm_head.set_embeddings_weights(self.transformer.wte.weight) + self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens) def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None): hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past) diff --git a/pytorch_pretrained_bert/tokenization_gpt2.py b/pytorch_pretrained_bert/tokenization_gpt2.py index 8ffd7a68e2..c18589b7b0 100644 --- a/pytorch_pretrained_bert/tokenization_gpt2.py +++ b/pytorch_pretrained_bert/tokenization_gpt2.py @@ -263,8 +263,8 @@ class GPT2Tokenizer(object): def encode(self, text): return self.convert_tokens_to_ids(self.tokenize(text)) - def decode(self, tokens): - text = ''.join([self.decoder[token] for token in tokens]) + def decode(self, tokens, skip_special_tokens=False): + text = ''.join(self.convert_ids_to_tokens(tokens, skip_special_tokens=skip_special_tokens)) text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) return text