From 366a3b02857a1fdae447358cc76bf8abf1bf11eb Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 8 May 2019 21:43:51 +0200 Subject: [PATCH] clean up in tokenization --- pytorch_pretrained_bert/modeling_gpt2.py | 6 ++++-- pytorch_pretrained_bert/tokenization_gpt2.py | 9 ++++++++- pytorch_pretrained_bert/tokenization_openai.py | 2 +- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/pytorch_pretrained_bert/modeling_gpt2.py b/pytorch_pretrained_bert/modeling_gpt2.py index 7623e4ddad..0554442b7f 100644 --- a/pytorch_pretrained_bert/modeling_gpt2.py +++ b/pytorch_pretrained_bert/modeling_gpt2.py @@ -39,8 +39,10 @@ from .modeling import BertLayerNorm as LayerNorm logger = logging.getLogger(__name__) -PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"} -PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json"} +PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin", + "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin"} +PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json", + "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"} def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path): """ Load tf checkpoints in a pytorch model diff --git a/pytorch_pretrained_bert/tokenization_gpt2.py b/pytorch_pretrained_bert/tokenization_gpt2.py index c18589b7b0..c66af3ff13 100644 --- a/pytorch_pretrained_bert/tokenization_gpt2.py +++ b/pytorch_pretrained_bert/tokenization_gpt2.py @@ -37,9 +37,11 @@ logger = logging.getLogger(__name__) PRETRAINED_VOCAB_ARCHIVE_MAP = { 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", + 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json", } PRETRAINED_MERGES_ARCHIVE_MAP = { 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", + 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt", } PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 'gpt2': 1024, @@ -263,9 +265,14 @@ class GPT2Tokenizer(object): def encode(self, text): return self.convert_tokens_to_ids(self.tokenize(text)) - def decode(self, tokens, skip_special_tokens=False): + def decode(self, tokens, skip_special_tokens=False, clean_up_tokenization_spaces=True): text = ''.join(self.convert_ids_to_tokens(tokens, skip_special_tokens=skip_special_tokens)) text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) + if clean_up_tokenization_spaces: + text = text.replace('', '') + text = text.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',' + ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't" + ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re") return text def save_vocabulary(self, vocab_path): diff --git a/pytorch_pretrained_bert/tokenization_openai.py b/pytorch_pretrained_bert/tokenization_openai.py index 214a476ce9..c68e247e1e 100644 --- a/pytorch_pretrained_bert/tokenization_openai.py +++ b/pytorch_pretrained_bert/tokenization_openai.py @@ -272,7 +272,7 @@ class OpenAIGPTTokenizer(object): out_string = ''.join(tokens).replace('', ' ').strip() if clean_up_tokenization_spaces: out_string = out_string.replace('', '') - out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',').replace(' ,', ',' + out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',' ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't" ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re") return out_string