GPT-2 XL
This commit is contained in:
@@ -73,6 +73,9 @@ Here is the full list of the currently provided pretrained models together with
|
|||||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
| | ``gpt2-large`` | | 36-layer, 1280-hidden, 20-heads, 774M parameters. |
|
| | ``gpt2-large`` | | 36-layer, 1280-hidden, 20-heads, 774M parameters. |
|
||||||
| | | | OpenAI's Large-sized GPT-2 English model |
|
| | | | OpenAI's Large-sized GPT-2 English model |
|
||||||
|
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
|
| | ``gpt2-xl`` | | 48-layer, 1600-hidden, 25-heads, 1558M parameters. |
|
||||||
|
| | | | OpenAI's XL-sized GPT-2 English model |
|
||||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
| Transformer-XL | ``transfo-xl-wt103`` | | 18-layer, 1024-hidden, 16-heads, 257M parameters. |
|
| Transformer-XL | ``transfo-xl-wt103`` | | 18-layer, 1024-hidden, 16-heads, 257M parameters. |
|
||||||
| | | | English model trained on wikitext-103 |
|
| | | | English model trained on wikitext-103 |
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ logger = logging.getLogger(__name__)
|
|||||||
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
|
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
|
||||||
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json",
|
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json",
|
||||||
"gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json",
|
"gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json",
|
||||||
|
"gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-config.json",
|
||||||
"distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-config.json",}
|
"distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-config.json",}
|
||||||
|
|
||||||
class GPT2Config(PretrainedConfig):
|
class GPT2Config(PretrainedConfig):
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ logger = logging.getLogger(__name__)
|
|||||||
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
|
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
|
||||||
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin",
|
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin",
|
||||||
"gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-pytorch_model.bin",
|
"gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-pytorch_model.bin",
|
||||||
|
"gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-pytorch_model.bin",
|
||||||
"distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-pytorch_model.bin",}
|
"distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-pytorch_model.bin",}
|
||||||
|
|
||||||
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ PRETRAINED_VOCAB_FILES_MAP = {
|
|||||||
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
|
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
|
||||||
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
|
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
|
||||||
'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json",
|
'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json",
|
||||||
|
'gpt2-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-vocab.json",
|
||||||
'distilgpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-vocab.json",
|
'distilgpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-vocab.json",
|
||||||
},
|
},
|
||||||
'merges_file':
|
'merges_file':
|
||||||
@@ -53,6 +54,7 @@ PRETRAINED_VOCAB_FILES_MAP = {
|
|||||||
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
|
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
|
||||||
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
|
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
|
||||||
'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt",
|
'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt",
|
||||||
|
'gpt2-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-merges.txt",
|
||||||
'distilgpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-merges.txt",
|
'distilgpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-merges.txt",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -61,6 +63,7 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
|||||||
'gpt2': 1024,
|
'gpt2': 1024,
|
||||||
'gpt2-medium': 1024,
|
'gpt2-medium': 1024,
|
||||||
'gpt2-large': 1024,
|
'gpt2-large': 1024,
|
||||||
|
'gpt2-xl': 1024,
|
||||||
'distilgpt2': 1024,
|
'distilgpt2': 1024,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user