update conversion scripts and __main__
This commit is contained in:
@@ -1,14 +1,15 @@
|
|||||||
# coding: utf8
|
# coding: utf8
|
||||||
def main():
|
def main():
|
||||||
import sys
|
import sys
|
||||||
if (len(sys.argv) < 4 or len(sys.argv) > 6) or sys.argv[1] not in ["bert", "gpt", "transfo_xl", "gpt2", "xlnet"]:
|
if (len(sys.argv) < 4 or len(sys.argv) > 6) or sys.argv[1] not in ["bert", "gpt", "transfo_xl", "gpt2", "xlnet", "xlm"]:
|
||||||
print(
|
print(
|
||||||
"Should be used as one of: \n"
|
"Should be used as one of: \n"
|
||||||
">> `pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n"
|
">> pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT, \n"
|
||||||
">> `pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n"
|
">> pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG], \n"
|
||||||
">> `pytorch_transformers transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n"
|
">> pytorch_transformers transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG] or \n"
|
||||||
">> `pytorch_transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]` or \n"
|
">> pytorch_transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG] or \n"
|
||||||
">> `pytorch_transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]`")
|
">> pytorch_transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME] or \n"
|
||||||
|
">> pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT")
|
||||||
else:
|
else:
|
||||||
if sys.argv[1] == "bert":
|
if sys.argv[1] == "bert":
|
||||||
try:
|
try:
|
||||||
@@ -86,7 +87,7 @@ def main():
|
|||||||
else:
|
else:
|
||||||
TF_CONFIG = ""
|
TF_CONFIG = ""
|
||||||
convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
|
convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
|
||||||
else:
|
elif sys.argv[1] == "xlnet":
|
||||||
try:
|
try:
|
||||||
from .convert_xlnet_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch
|
from .convert_xlnet_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -104,11 +105,24 @@ def main():
|
|||||||
PYTORCH_DUMP_OUTPUT = sys.argv[4]
|
PYTORCH_DUMP_OUTPUT = sys.argv[4]
|
||||||
if len(sys.argv) == 6:
|
if len(sys.argv) == 6:
|
||||||
FINETUNING_TASK = sys.argv[5]
|
FINETUNING_TASK = sys.argv[5]
|
||||||
|
else:
|
||||||
|
FINETUNING_TASK = None
|
||||||
|
|
||||||
convert_xlnet_checkpoint_to_pytorch(TF_CHECKPOINT,
|
convert_xlnet_checkpoint_to_pytorch(TF_CHECKPOINT,
|
||||||
TF_CONFIG,
|
TF_CONFIG,
|
||||||
PYTORCH_DUMP_OUTPUT,
|
PYTORCH_DUMP_OUTPUT,
|
||||||
FINETUNING_TASK)
|
FINETUNING_TASK)
|
||||||
|
elif sys.argv[1] == "xlm":
|
||||||
|
from .convert_xlm_checkpoint_to_pytorch import convert_xlm_checkpoint_to_pytorch
|
||||||
|
|
||||||
|
if len(sys.argv) != 4:
|
||||||
|
# pylint: disable=line-too-long
|
||||||
|
print("Should be used as `pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT`")
|
||||||
|
else:
|
||||||
|
XLM_CHECKPOINT_PATH = sys.argv[2]
|
||||||
|
PYTORCH_DUMP_OUTPUT = sys.argv[3]
|
||||||
|
|
||||||
|
convert_xlm_checkpoint_to_pytorch(XLM_CHECKPOINT_PATH, PYTORCH_DUMP_OUTPUT)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -26,6 +26,9 @@ from pytorch_transformers.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME,
|
|||||||
GPT2Model,
|
GPT2Model,
|
||||||
load_tf_weights_in_gpt2)
|
load_tf_weights_in_gpt2)
|
||||||
|
|
||||||
|
import logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):
|
def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):
|
||||||
# Construct model
|
# Construct model
|
||||||
@@ -36,7 +39,7 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p
|
|||||||
model = GPT2Model(config)
|
model = GPT2Model(config)
|
||||||
|
|
||||||
# Load weights from numpy
|
# Load weights from numpy
|
||||||
load_tf_weights_in_gpt2(model, gpt2_checkpoint_path)
|
load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path)
|
||||||
|
|
||||||
# Save pytorch-model
|
# Save pytorch-model
|
||||||
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
|
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
|
||||||
|
|||||||
@@ -26,6 +26,9 @@ from pytorch_transformers.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME,
|
|||||||
OpenAIGPTModel,
|
OpenAIGPTModel,
|
||||||
load_tf_weights_in_openai_gpt)
|
load_tf_weights_in_openai_gpt)
|
||||||
|
|
||||||
|
import logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
|
def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
|
||||||
# Construct model
|
# Construct model
|
||||||
@@ -36,7 +39,7 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
|
|||||||
model = OpenAIGPTModel(config)
|
model = OpenAIGPTModel(config)
|
||||||
|
|
||||||
# Load weights from numpy
|
# Load weights from numpy
|
||||||
load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path)
|
load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path)
|
||||||
|
|
||||||
# Save pytorch-model
|
# Save pytorch-model
|
||||||
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
|
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
|
||||||
|
|||||||
@@ -18,15 +18,14 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import argparse
|
import argparse
|
||||||
import tensorflow as tf
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from pytorch_transformers.modeling_bert import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
from pytorch_transformers.modeling_bert import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
||||||
|
|
||||||
|
import logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
|
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
|
||||||
# Initialise PyTorch model
|
# Initialise PyTorch model
|
||||||
config = BertConfig.from_json_file(bert_config_file)
|
config = BertConfig.from_json_file(bert_config_file)
|
||||||
@@ -34,7 +33,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
|
|||||||
model = BertForPreTraining(config)
|
model = BertForPreTraining(config)
|
||||||
|
|
||||||
# Load weights from tf checkpoint
|
# Load weights from tf checkpoint
|
||||||
load_tf_weights_in_bert(model, tf_checkpoint_path)
|
load_tf_weights_in_bert(model, config, tf_checkpoint_path)
|
||||||
|
|
||||||
# Save pytorch-model
|
# Save pytorch-model
|
||||||
print("Save PyTorch model to {}".format(pytorch_dump_path))
|
print("Save PyTorch model to {}".format(pytorch_dump_path))
|
||||||
|
|||||||
@@ -36,6 +36,9 @@ if sys.version_info[0] == 2:
|
|||||||
else:
|
else:
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
|
import logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
# We do this to be able to load python 2 datasets pickles
|
# We do this to be able to load python 2 datasets pickles
|
||||||
# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
|
# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
|
||||||
data_utils.Vocab = data_utils.TransfoXLTokenizer
|
data_utils.Vocab = data_utils.TransfoXLTokenizer
|
||||||
|
|||||||
@@ -24,9 +24,10 @@ import torch
|
|||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
from pytorch_transformers.modeling_utils import CONFIG_NAME, WEIGHTS_NAME
|
from pytorch_transformers.modeling_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||||
from pytorch_transformers.modeling_xlm import (XLMConfig, XLMModel)
|
|
||||||
from pytorch_transformers.tokenization_xlm import VOCAB_FILES_NAMES
|
from pytorch_transformers.tokenization_xlm import VOCAB_FILES_NAMES
|
||||||
|
|
||||||
|
import logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):
|
def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):
|
||||||
# Load checkpoint
|
# Load checkpoint
|
||||||
|
|||||||
@@ -40,6 +40,8 @@ GLUE_TASKS_NUM_LABELS = {
|
|||||||
"wnli": 2,
|
"wnli": 2,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
import logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None):
|
def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None):
|
||||||
# Initialise PyTorch model
|
# Initialise PyTorch model
|
||||||
@@ -48,14 +50,17 @@ def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, py
|
|||||||
finetuning_task = finetuning_task.lower() if finetuning_task is not None else ""
|
finetuning_task = finetuning_task.lower() if finetuning_task is not None else ""
|
||||||
if finetuning_task in GLUE_TASKS_NUM_LABELS:
|
if finetuning_task in GLUE_TASKS_NUM_LABELS:
|
||||||
print("Building PyTorch XLNetForSequenceClassification model from configuration: {}".format(str(config)))
|
print("Building PyTorch XLNetForSequenceClassification model from configuration: {}".format(str(config)))
|
||||||
model = XLNetForSequenceClassification(config, num_labels=GLUE_TASKS_NUM_LABELS[finetuning_task])
|
config.finetuning_task = finetuning_task
|
||||||
|
config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task]
|
||||||
|
model = XLNetForSequenceClassification(config)
|
||||||
elif 'squad' in finetuning_task:
|
elif 'squad' in finetuning_task:
|
||||||
|
config.finetuning_task = finetuning_task
|
||||||
model = XLNetForQuestionAnswering(config)
|
model = XLNetForQuestionAnswering(config)
|
||||||
else:
|
else:
|
||||||
model = XLNetLMHeadModel(config)
|
model = XLNetLMHeadModel(config)
|
||||||
|
|
||||||
# Load weights from tf checkpoint
|
# Load weights from tf checkpoint
|
||||||
load_tf_weights_in_xlnet(model, config, tf_checkpoint_path, finetuning_task)
|
load_tf_weights_in_xlnet(model, config, tf_checkpoint_path)
|
||||||
|
|
||||||
# Save pytorch-model
|
# Save pytorch-model
|
||||||
pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
|
pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
|
||||||
|
|||||||
@@ -37,9 +37,11 @@ from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTra
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||||
|
'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-pytorch_model.bin",
|
||||||
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-pytorch_model.bin",
|
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-pytorch_model.bin",
|
||||||
}
|
}
|
||||||
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
|
'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json",
|
||||||
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
|
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ PRETRAINED_VOCAB_FILES_MAP = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||||
'transfo-xl-wt103': 512,
|
'transfo-xl-wt103': None,
|
||||||
}
|
}
|
||||||
|
|
||||||
PRETRAINED_CORPUS_ARCHIVE_MAP = {
|
PRETRAINED_CORPUS_ARCHIVE_MAP = {
|
||||||
|
|||||||
@@ -208,7 +208,8 @@ class PreTrainedTokenizer(object):
|
|||||||
# if we're using a pretrained model, ensure the tokenizer
|
# if we're using a pretrained model, ensure the tokenizer
|
||||||
# won't index sequences longer than the number of positional embeddings
|
# won't index sequences longer than the number of positional embeddings
|
||||||
max_len = cls.max_model_input_sizes[pretrained_model_name_or_path]
|
max_len = cls.max_model_input_sizes[pretrained_model_name_or_path]
|
||||||
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
|
if max_len is not None and isinstance(max_len, (int, float)):
|
||||||
|
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
|
||||||
|
|
||||||
# Merge resolved_vocab_files arguments in kwargs.
|
# Merge resolved_vocab_files arguments in kwargs.
|
||||||
added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None)
|
added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None)
|
||||||
|
|||||||
@@ -32,12 +32,14 @@ VOCAB_FILES_NAMES = {'vocab_file': 'spiece.model'}
|
|||||||
PRETRAINED_VOCAB_FILES_MAP = {
|
PRETRAINED_VOCAB_FILES_MAP = {
|
||||||
'vocab_file':
|
'vocab_file':
|
||||||
{
|
{
|
||||||
|
'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-spiece.model",
|
||||||
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-spiece.model",
|
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-spiece.model",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||||
'xlnet-large-cased': 512,
|
'xlnet-base-cased': None,
|
||||||
|
'xlnet-large-cased': None,
|
||||||
}
|
}
|
||||||
|
|
||||||
SPIECE_UNDERLINE = u'▁'
|
SPIECE_UNDERLINE = u'▁'
|
||||||
|
|||||||
Reference in New Issue
Block a user