From 6be46a6e6422d7ab34984d6fbcaadad0323e5349 Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Thu, 3 Oct 2019 10:07:18 -0400 Subject: [PATCH] update links to new weights --- transformers/configuration_gpt2.py | 3 ++- transformers/modeling_gpt2.py | 3 ++- transformers/modeling_tf_gpt2.py | 3 ++- transformers/tokenization_gpt2.py | 3 +++ 4 files changed, 9 insertions(+), 3 deletions(-) diff --git a/transformers/configuration_gpt2.py b/transformers/configuration_gpt2.py index c83d9e82ce..e7d853f317 100644 --- a/transformers/configuration_gpt2.py +++ b/transformers/configuration_gpt2.py @@ -28,7 +28,8 @@ logger = logging.getLogger(__name__) GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json", "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json", - "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json"} + "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json", + "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-config.json",} class GPT2Config(PretrainedConfig): """Configuration class to store the configuration of a `GPT2Model`. diff --git a/transformers/modeling_gpt2.py b/transformers/modeling_gpt2.py index bc85224022..891dfc5677 100644 --- a/transformers/modeling_gpt2.py +++ b/transformers/modeling_gpt2.py @@ -38,7 +38,8 @@ logger = logging.getLogger(__name__) GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin", "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin", - "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-pytorch_model.bin"} + "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-pytorch_model.bin", + "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-pytorch_model.bin",} def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): """ Load tf checkpoints in a pytorch model diff --git a/transformers/modeling_tf_gpt2.py b/transformers/modeling_tf_gpt2.py index e958c2cbf1..883340cac9 100644 --- a/transformers/modeling_tf_gpt2.py +++ b/transformers/modeling_tf_gpt2.py @@ -38,7 +38,8 @@ logger = logging.getLogger(__name__) TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-tf_model.h5", "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-tf_model.h5", - "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-tf_model.h5"} + "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-tf_model.h5", + "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-tf_model.h5",} def load_gpt2_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path): diff --git a/transformers/tokenization_gpt2.py b/transformers/tokenization_gpt2.py index 3e931dfcf8..6a7f75acb2 100644 --- a/transformers/tokenization_gpt2.py +++ b/transformers/tokenization_gpt2.py @@ -46,12 +46,14 @@ PRETRAINED_VOCAB_FILES_MAP = { 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json", 'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json", + 'distilgpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-vocab.json", }, 'merges_file': { 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt", 'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt", + 'distilgpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-merges.txt", }, } @@ -59,6 +61,7 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 'gpt2': 1024, 'gpt2-medium': 1024, 'gpt2-large': 1024, + 'distilgpt2': 1024, } @lru_cache()