From aa05dc8935a3e5b349abecbdc5399796578fe965 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 21 Aug 2019 02:29:34 +0200 Subject: [PATCH] adding gpt-2 large --- pytorch_transformers/convert_gpt2_checkpoint_to_pytorch.py | 2 +- .../convert_openai_checkpoint_to_pytorch.py | 2 +- .../convert_transfo_xl_checkpoint_to_pytorch.py | 2 +- pytorch_transformers/modeling_gpt2.py | 6 ++++-- pytorch_transformers/tokenization_gpt2.py | 2 ++ 5 files changed, 9 insertions(+), 5 deletions(-) diff --git a/pytorch_transformers/convert_gpt2_checkpoint_to_pytorch.py b/pytorch_transformers/convert_gpt2_checkpoint_to_pytorch.py index f9e83f5d5b..e9bfa0302a 100755 --- a/pytorch_transformers/convert_gpt2_checkpoint_to_pytorch.py +++ b/pytorch_transformers/convert_gpt2_checkpoint_to_pytorch.py @@ -35,7 +35,7 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p if gpt2_config_file == "": config = GPT2Config() else: - config = GPT2Config(gpt2_config_file) + config = GPT2Config.from_json_file(gpt2_config_file) model = GPT2Model(config) # Load weights from numpy diff --git a/pytorch_transformers/convert_openai_checkpoint_to_pytorch.py b/pytorch_transformers/convert_openai_checkpoint_to_pytorch.py index 70895b4002..3009f8a99e 100755 --- a/pytorch_transformers/convert_openai_checkpoint_to_pytorch.py +++ b/pytorch_transformers/convert_openai_checkpoint_to_pytorch.py @@ -35,7 +35,7 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c if openai_config_file == "": config = OpenAIGPTConfig() else: - config = OpenAIGPTConfig(openai_config_file) + config = OpenAIGPTConfig.from_json_file(openai_config_file) model = OpenAIGPTModel(config) # Load weights from numpy diff --git a/pytorch_transformers/convert_transfo_xl_checkpoint_to_pytorch.py b/pytorch_transformers/convert_transfo_xl_checkpoint_to_pytorch.py index 5733146444..7e79d58d7d 100755 --- a/pytorch_transformers/convert_transfo_xl_checkpoint_to_pytorch.py +++ b/pytorch_transformers/convert_transfo_xl_checkpoint_to_pytorch.py @@ -75,7 +75,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, if transfo_xl_config_file == "": config = TransfoXLConfig() else: - config = TransfoXLConfig(transfo_xl_config_file) + config = TransfoXLConfig.from_json_file(transfo_xl_config_file) print("Building PyTorch model from configuration: {}".format(str(config))) model = TransfoXLLMHeadModel(config) diff --git a/pytorch_transformers/modeling_gpt2.py b/pytorch_transformers/modeling_gpt2.py index cb4b8cc4ab..9022048d6d 100644 --- a/pytorch_transformers/modeling_gpt2.py +++ b/pytorch_transformers/modeling_gpt2.py @@ -38,9 +38,11 @@ from .modeling_bert import BertLayerNorm as LayerNorm logger = logging.getLogger(__name__) GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin", - "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin"} + "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin", + "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-pytorch_model.bin"} GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json", - "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"} + "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json", + "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json"} def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): """ Load tf checkpoints in a pytorch model diff --git a/pytorch_transformers/tokenization_gpt2.py b/pytorch_transformers/tokenization_gpt2.py index 0aee856180..4016a85a7f 100644 --- a/pytorch_transformers/tokenization_gpt2.py +++ b/pytorch_transformers/tokenization_gpt2.py @@ -45,11 +45,13 @@ PRETRAINED_VOCAB_FILES_MAP = { { 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json", + 'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json", }, 'merges_file': { 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt", + 'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt", }, }