From d7d36181fdefdabadc53adf51bed4a2680f5880a Mon Sep 17 00:00:00 2001 From: Lysandre Date: Tue, 5 Nov 2019 17:49:01 +0000 Subject: [PATCH] GPT-2 XL --- docs/source/pretrained_models.rst | 3 +++ transformers/configuration_gpt2.py | 1 + transformers/modeling_gpt2.py | 1 + transformers/tokenization_gpt2.py | 3 +++ 4 files changed, 8 insertions(+) diff --git a/docs/source/pretrained_models.rst b/docs/source/pretrained_models.rst index 43c08228bd..559c81cbb0 100644 --- a/docs/source/pretrained_models.rst +++ b/docs/source/pretrained_models.rst @@ -73,6 +73,9 @@ Here is the full list of the currently provided pretrained models together with | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ | | ``gpt2-large`` | | 36-layer, 1280-hidden, 20-heads, 774M parameters. | | | | | OpenAI's Large-sized GPT-2 English model | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``gpt2-xl`` | | 48-layer, 1600-hidden, 25-heads, 1558M parameters. | +| | | | OpenAI's XL-sized GPT-2 English model | +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ | Transformer-XL | ``transfo-xl-wt103`` | | 18-layer, 1024-hidden, 16-heads, 257M parameters. | | | | | English model trained on wikitext-103 | diff --git a/transformers/configuration_gpt2.py b/transformers/configuration_gpt2.py index e7d853f317..c2fb4948d3 100644 --- a/transformers/configuration_gpt2.py +++ b/transformers/configuration_gpt2.py @@ -29,6 +29,7 @@ logger = logging.getLogger(__name__) GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json", "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json", "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json", + "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-config.json", "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-config.json",} class GPT2Config(PretrainedConfig): diff --git a/transformers/modeling_gpt2.py b/transformers/modeling_gpt2.py index 6812b235f5..e3d26797c8 100644 --- a/transformers/modeling_gpt2.py +++ b/transformers/modeling_gpt2.py @@ -39,6 +39,7 @@ logger = logging.getLogger(__name__) GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin", "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin", "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-pytorch_model.bin", + "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-pytorch_model.bin", "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-pytorch_model.bin",} def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): diff --git a/transformers/tokenization_gpt2.py b/transformers/tokenization_gpt2.py index 6a7f75acb2..4bec515903 100644 --- a/transformers/tokenization_gpt2.py +++ b/transformers/tokenization_gpt2.py @@ -46,6 +46,7 @@ PRETRAINED_VOCAB_FILES_MAP = { 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json", 'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json", + 'gpt2-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-vocab.json", 'distilgpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-vocab.json", }, 'merges_file': @@ -53,6 +54,7 @@ PRETRAINED_VOCAB_FILES_MAP = { 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt", 'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt", + 'gpt2-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-merges.txt", 'distilgpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-merges.txt", }, } @@ -61,6 +63,7 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 'gpt2': 1024, 'gpt2-medium': 1024, 'gpt2-large': 1024, + 'gpt2-xl': 1024, 'distilgpt2': 1024, }