diff --git a/pytorch_transformers/tokenization_gpt2.py b/pytorch_transformers/tokenization_gpt2.py index 43c57c9cd3..afcdf1e64e 100644 --- a/pytorch_transformers/tokenization_gpt2.py +++ b/pytorch_transformers/tokenization_gpt2.py @@ -102,7 +102,7 @@ class GPT2Tokenizer(PreTrainedTokenizer): pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - def __init__(self, vocab_file, merges_file, errors='replace', + def __init__(self, vocab_file, merges_file, errors='replace', unk_token="<|endoftext|>", bos_token="<|endoftext|>", eos_token="<|endoftext|>", **kwargs): super(GPT2Tokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, **kwargs) @@ -177,9 +177,7 @@ class GPT2Tokenizer(PreTrainedTokenizer): def _convert_token_to_id(self, token): """ Converts a token (str/unicode) in an id using the vocab. """ - if token in self.encoder: - return self.encoder.get(token) - return self.encoder.get(self.unk_token) + return self.encoder.get(token, self.encoder.get(self.unk_token)) def _convert_id_to_token(self, index): """Converts an index (integer) in a token (string/unicode) using the vocab."""