adding gpt-2 large
This commit is contained in:
@@ -35,7 +35,7 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p
|
|||||||
if gpt2_config_file == "":
|
if gpt2_config_file == "":
|
||||||
config = GPT2Config()
|
config = GPT2Config()
|
||||||
else:
|
else:
|
||||||
config = GPT2Config(gpt2_config_file)
|
config = GPT2Config.from_json_file(gpt2_config_file)
|
||||||
model = GPT2Model(config)
|
model = GPT2Model(config)
|
||||||
|
|
||||||
# Load weights from numpy
|
# Load weights from numpy
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
|
|||||||
if openai_config_file == "":
|
if openai_config_file == "":
|
||||||
config = OpenAIGPTConfig()
|
config = OpenAIGPTConfig()
|
||||||
else:
|
else:
|
||||||
config = OpenAIGPTConfig(openai_config_file)
|
config = OpenAIGPTConfig.from_json_file(openai_config_file)
|
||||||
model = OpenAIGPTModel(config)
|
model = OpenAIGPTModel(config)
|
||||||
|
|
||||||
# Load weights from numpy
|
# Load weights from numpy
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
|
|||||||
if transfo_xl_config_file == "":
|
if transfo_xl_config_file == "":
|
||||||
config = TransfoXLConfig()
|
config = TransfoXLConfig()
|
||||||
else:
|
else:
|
||||||
config = TransfoXLConfig(transfo_xl_config_file)
|
config = TransfoXLConfig.from_json_file(transfo_xl_config_file)
|
||||||
print("Building PyTorch model from configuration: {}".format(str(config)))
|
print("Building PyTorch model from configuration: {}".format(str(config)))
|
||||||
model = TransfoXLLMHeadModel(config)
|
model = TransfoXLLMHeadModel(config)
|
||||||
|
|
||||||
|
|||||||
@@ -38,9 +38,11 @@ from .modeling_bert import BertLayerNorm as LayerNorm
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
|
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
|
||||||
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin"}
|
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin",
|
||||||
|
"gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-pytorch_model.bin"}
|
||||||
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
|
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
|
||||||
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"}
|
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json",
|
||||||
|
"gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json"}
|
||||||
|
|
||||||
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
||||||
""" Load tf checkpoints in a pytorch model
|
""" Load tf checkpoints in a pytorch model
|
||||||
|
|||||||
@@ -45,11 +45,13 @@ PRETRAINED_VOCAB_FILES_MAP = {
|
|||||||
{
|
{
|
||||||
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
|
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
|
||||||
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
|
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
|
||||||
|
'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json",
|
||||||
},
|
},
|
||||||
'merges_file':
|
'merges_file':
|
||||||
{
|
{
|
||||||
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
|
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
|
||||||
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
|
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
|
||||||
|
'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user