From e3fb4310d6ce63ea55e814544fa47207bc3f72f9 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Thu, 11 Jul 2019 18:44:29 -0400 Subject: [PATCH] From pretrained correct initialization. Unknown token handling for gpt2. --- pytorch_transformers/modeling_gpt2.py | 2 +- pytorch_transformers/modeling_openai.py | 2 +- pytorch_transformers/tokenization_gpt2.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_transformers/modeling_gpt2.py b/pytorch_transformers/modeling_gpt2.py index 495e002529..2b8ec88a50 100644 --- a/pytorch_transformers/modeling_gpt2.py +++ b/pytorch_transformers/modeling_gpt2.py @@ -423,7 +423,7 @@ class GPT2PreTrainedModel(PreTrainedModel): """ num_special_tokens = kwargs.pop('num_special_tokens', None) - model = super().from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) + model = super(GPT2PreTrainedModel, cls).from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) # Add additional embeddings for special tokens if needed # This step also make sure we are still sharing the output and input embeddings after loading weights diff --git a/pytorch_transformers/modeling_openai.py b/pytorch_transformers/modeling_openai.py index aa35b163f1..de7ab8ae4b 100644 --- a/pytorch_transformers/modeling_openai.py +++ b/pytorch_transformers/modeling_openai.py @@ -431,7 +431,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel): num_special_tokens = kwargs.get('num_special_tokens', None) kwargs.pop('num_special_tokens', None) - model = super(PreTrainedModel, cls).from_pretrained(pretrained_model_name_or_path, pretrained_model_name_or_path, *inputs, **kwargs) + model = super(OpenAIGPTPreTrainedModel, cls).from_pretrained(pretrained_model_name_or_path, pretrained_model_name_or_path, *inputs, **kwargs) # Add additional embeddings for special tokens if needed # This step also make sure we are still sharing the output and input embeddings after loading weights diff --git a/pytorch_transformers/tokenization_gpt2.py b/pytorch_transformers/tokenization_gpt2.py index af1ad2cf8f..6084dc3e05 100644 --- a/pytorch_transformers/tokenization_gpt2.py +++ b/pytorch_transformers/tokenization_gpt2.py @@ -177,11 +177,11 @@ class GPT2Tokenizer(PreTrainedTokenizer): def _convert_token_to_id(self, token): """ Converts a token (str/unicode) in an id using the vocab. """ - return self.encoder.get(token, self.encoder.get(self.unk_token)) + return self.encoder.get(token) def _convert_id_to_token(self, index): """Converts an index (integer) in a token (string/unicode) using the vocab.""" - return self.decoder.get(index, self.unk_token) + return self.decoder.get(index) def _convert_ids_to_string(self, tokens_ids): """Converts a sequence of ids in a string."""