From 603c513b35d1daf623e48eb68d54e06502d5e17d Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 25 Jun 2019 10:45:07 +0200 Subject: [PATCH] update main conversion script and readme --- README.md | 24 +++- pytorch_pretrained_bert/__main__.py | 113 +++++++++++------- .../convert_xlnet_checkpoint_to_pytorch.py | 8 +- 3 files changed, 96 insertions(+), 49 deletions(-) diff --git a/README.md b/README.md index 4f69cbdd19..e7834f8605 100644 --- a/README.md +++ b/README.md @@ -1690,7 +1690,7 @@ Here is an example of the conversion process for a pre-trained `BERT-Base Uncase ```shell export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12 -pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch \ +pytorch_pretrained_bert bert \ $BERT_BASE_DIR/bert_model.ckpt \ $BERT_BASE_DIR/bert_config.json \ $BERT_BASE_DIR/pytorch_model.bin @@ -1705,7 +1705,7 @@ Here is an example of the conversion process for a pre-trained OpenAI GPT model, ```shell export OPENAI_GPT_CHECKPOINT_FOLDER_PATH=/path/to/openai/pretrained/numpy/weights -pytorch_pretrained_bert convert_openai_checkpoint \ +pytorch_pretrained_bert gpt \ $OPENAI_GPT_CHECKPOINT_FOLDER_PATH \ $PYTORCH_DUMP_OUTPUT \ [OPENAI_GPT_CONFIG] @@ -1718,7 +1718,7 @@ Here is an example of the conversion process for a pre-trained Transformer-XL mo ```shell export TRANSFO_XL_CHECKPOINT_FOLDER_PATH=/path/to/transfo/xl/checkpoint -pytorch_pretrained_bert convert_transfo_xl_checkpoint \ +pytorch_pretrained_bert transfo_xl \ $TRANSFO_XL_CHECKPOINT_FOLDER_PATH \ $PYTORCH_DUMP_OUTPUT \ [TRANSFO_XL_CONFIG] @@ -1731,12 +1731,28 @@ Here is an example of the conversion process for a pre-trained OpenAI's GPT-2 mo ```shell export GPT2_DIR=/path/to/gpt2/checkpoint -pytorch_pretrained_bert convert_gpt2_checkpoint \ +pytorch_pretrained_bert gpt2 \ $GPT2_DIR/model.ckpt \ $PYTORCH_DUMP_OUTPUT \ [GPT2_CONFIG] ``` +### XLNet + +Here is an example of the conversion process for a pre-trained XLNet model, fine-tuned on STS-B using the TensorFlow script: + +```shell +export TRANSFO_XL_CHECKPOINT_PATH=/path/to/xlnet/checkpoint +export TRANSFO_XL_CONFIG_PATH=/path/to/xlnet/config + +pytorch_pretrained_bert xlnet \ + $TRANSFO_XL_CHECKPOINT_PATH \ + $TRANSFO_XL_CONFIG_PATH \ + $PYTORCH_DUMP_OUTPUT \ + STS-B \ +``` + + ## TPU TPU support and pretraining scripts diff --git a/pytorch_pretrained_bert/__main__.py b/pytorch_pretrained_bert/__main__.py index a2aae9e9ce..bb9534a830 100644 --- a/pytorch_pretrained_bert/__main__.py +++ b/pytorch_pretrained_bert/__main__.py @@ -1,20 +1,16 @@ # coding: utf8 def main(): import sys - if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [ - "convert_tf_checkpoint_to_pytorch", - "convert_openai_checkpoint", - "convert_transfo_xl_checkpoint", - "convert_gpt2_checkpoint", - ]: + if (len(sys.argv) < 4 or len(sys.argv) > 6) or sys.argv[1] not in ["bert", "gpt", "transfo_xl", "gpt2", "xlnet"]: print( "Should be used as one of: \n" - ">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n" - ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n" - ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n" - ">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`") + ">> `pytorch_pretrained_bert bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n" + ">> `pytorch_pretrained_bert gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n" + ">> `pytorch_pretrained_bert transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n" + ">> `pytorch_pretrained_bert gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]` or \n" + ">> `pytorch_pretrained_bert xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]`") else: - if sys.argv[1] == "convert_tf_checkpoint_to_pytorch": + if sys.argv[1] == "bert": try: from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch except ImportError: @@ -25,24 +21,28 @@ def main(): if len(sys.argv) != 5: # pylint: disable=line-too-long - print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") + print("Should be used as `pytorch_pretrained_bert bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") else: PYTORCH_DUMP_OUTPUT = sys.argv.pop() TF_CONFIG = sys.argv.pop() TF_CHECKPOINT = sys.argv.pop() convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) - elif sys.argv[1] == "convert_openai_checkpoint": + elif sys.argv[1] == "gpt": from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch - OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2] - PYTORCH_DUMP_OUTPUT = sys.argv[3] - if len(sys.argv) == 5: - OPENAI_GPT_CONFIG = sys.argv[4] + if len(sys.argv) < 4 or len(sys.argv) > 5: + # pylint: disable=line-too-long + print("Should be used as `pytorch_pretrained_bert gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`") else: - OPENAI_GPT_CONFIG = "" - convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH, - OPENAI_GPT_CONFIG, - PYTORCH_DUMP_OUTPUT) - elif sys.argv[1] == "convert_transfo_xl_checkpoint": + OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2] + PYTORCH_DUMP_OUTPUT = sys.argv[3] + if len(sys.argv) == 5: + OPENAI_GPT_CONFIG = sys.argv[4] + else: + OPENAI_GPT_CONFIG = "" + convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH, + OPENAI_GPT_CONFIG, + PYTORCH_DUMP_OUTPUT) + elif sys.argv[1] == "transfo_xl": try: from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch except ImportError: @@ -50,20 +50,23 @@ def main(): "In that case, it requires TensorFlow to be installed. Please see " "https://www.tensorflow.org/install/ for installation instructions.") raise - - if 'ckpt' in sys.argv[2].lower(): - TF_CHECKPOINT = sys.argv[2] - TF_DATASET_FILE = "" + if len(sys.argv) < 4 or len(sys.argv) > 5: + # pylint: disable=line-too-long + print("Should be used as `pytorch_pretrained_bert transfo_xl TF_CHECKPOINT/TF_DATASET_FILE PYTORCH_DUMP_OUTPUT [TF_CONFIG]`") else: - TF_DATASET_FILE = sys.argv[2] - TF_CHECKPOINT = "" - PYTORCH_DUMP_OUTPUT = sys.argv[3] - if len(sys.argv) == 5: - TF_CONFIG = sys.argv[4] - else: - TF_CONFIG = "" - convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE) - else: + if 'ckpt' in sys.argv[2].lower(): + TF_CHECKPOINT = sys.argv[2] + TF_DATASET_FILE = "" + else: + TF_DATASET_FILE = sys.argv[2] + TF_CHECKPOINT = "" + PYTORCH_DUMP_OUTPUT = sys.argv[3] + if len(sys.argv) == 5: + TF_CONFIG = sys.argv[4] + else: + TF_CONFIG = "" + convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE) + elif sys.argv[1] == "gpt2": try: from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch except ImportError: @@ -72,12 +75,40 @@ def main(): "https://www.tensorflow.org/install/ for installation instructions.") raise - TF_CHECKPOINT = sys.argv[2] - PYTORCH_DUMP_OUTPUT = sys.argv[3] - if len(sys.argv) == 5: - TF_CONFIG = sys.argv[4] + if len(sys.argv) < 4 or len(sys.argv) > 5: + # pylint: disable=line-too-long + print("Should be used as `pytorch_pretrained_bert gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [TF_CONFIG]`") else: - TF_CONFIG = "" - convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) + TF_CHECKPOINT = sys.argv[2] + PYTORCH_DUMP_OUTPUT = sys.argv[3] + if len(sys.argv) == 5: + TF_CONFIG = sys.argv[4] + else: + TF_CONFIG = "" + convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) + else: + try: + from .convert_xlnet_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch + except ImportError: + print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " + "In that case, it requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions.") + raise + + if len(sys.argv) < 5 or len(sys.argv) > 6: + # pylint: disable=line-too-long + print("Should be used as `pytorch_pretrained_bert xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]`") + else: + TF_CHECKPOINT = sys.argv[2] + TF_CONFIG = sys.argv[3] + PYTORCH_DUMP_OUTPUT = sys.argv[4] + if len(sys.argv) == 6: + FINETUNING_TASK = sys.argv[5] + + convert_xlnet_checkpoint_to_pytorch(TF_CHECKPOINT, + TF_CONFIG, + PYTORCH_DUMP_OUTPUT, + FINETUNING_TASK) + if __name__ == '__main__': main() diff --git a/pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py b/pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py index e56cb538f4..d343fd2189 100755 --- a/pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py +++ b/pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py @@ -70,7 +70,7 @@ if __name__ == "__main__": required = True, help = "The config json file corresponding to the pre-trained XLNet model. \n" "This specifies the model architecture.") - parser.add_argument("--pytorch_dump_folder_path",finetuning_task + parser.add_argument("--pytorch_dump_folder_path", default = None, type = str, required = True, @@ -81,6 +81,6 @@ if __name__ == "__main__": help = "Name of a task on which the XLNet TensorFloaw model was fine-tuned") args = parser.parse_args() convert_xlnet_checkpoint_to_pytorch(args.tf_checkpoint_path, - args.xlnet_config_file, - args.pytorch_dump_folder_path, - args.finetuning_task) + args.xlnet_config_file, + args.pytorch_dump_folder_path, + args.finetuning_task)