From 1b35d05d4b3c121a9740544aa6f884f1039780b1 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 16 Jul 2019 09:41:55 +0200 Subject: [PATCH] update conversion scripts and __main__ --- pytorch_transformers/__main__.py | 28 ++++++++++++++----- .../convert_gpt2_checkpoint_to_pytorch.py | 5 +++- .../convert_openai_checkpoint_to_pytorch.py | 5 +++- .../convert_tf_checkpoint_to_pytorch.py | 9 +++--- ...onvert_transfo_xl_checkpoint_to_pytorch.py | 3 ++ .../convert_xlm_checkpoint_to_pytorch.py | 3 +- .../convert_xlnet_checkpoint_to_pytorch.py | 9 ++++-- pytorch_transformers/modeling_xlnet.py | 2 ++ .../tokenization_transfo_xl.py | 2 +- pytorch_transformers/tokenization_utils.py | 3 +- pytorch_transformers/tokenization_xlnet.py | 4 ++- 11 files changed, 53 insertions(+), 20 deletions(-) diff --git a/pytorch_transformers/__main__.py b/pytorch_transformers/__main__.py index 95504c1493..b047fa7447 100644 --- a/pytorch_transformers/__main__.py +++ b/pytorch_transformers/__main__.py @@ -1,14 +1,15 @@ # coding: utf8 def main(): import sys - if (len(sys.argv) < 4 or len(sys.argv) > 6) or sys.argv[1] not in ["bert", "gpt", "transfo_xl", "gpt2", "xlnet"]: + if (len(sys.argv) < 4 or len(sys.argv) > 6) or sys.argv[1] not in ["bert", "gpt", "transfo_xl", "gpt2", "xlnet", "xlm"]: print( "Should be used as one of: \n" - ">> `pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n" - ">> `pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n" - ">> `pytorch_transformers transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n" - ">> `pytorch_transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]` or \n" - ">> `pytorch_transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]`") + ">> pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT, \n" + ">> pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG], \n" + ">> pytorch_transformers transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG] or \n" + ">> pytorch_transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG] or \n" + ">> pytorch_transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME] or \n" + ">> pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT") else: if sys.argv[1] == "bert": try: @@ -86,7 +87,7 @@ def main(): else: TF_CONFIG = "" convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) - else: + elif sys.argv[1] == "xlnet": try: from .convert_xlnet_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch except ImportError: @@ -104,11 +105,24 @@ def main(): PYTORCH_DUMP_OUTPUT = sys.argv[4] if len(sys.argv) == 6: FINETUNING_TASK = sys.argv[5] + else: + FINETUNING_TASK = None convert_xlnet_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, FINETUNING_TASK) + elif sys.argv[1] == "xlm": + from .convert_xlm_checkpoint_to_pytorch import convert_xlm_checkpoint_to_pytorch + + if len(sys.argv) != 4: + # pylint: disable=line-too-long + print("Should be used as `pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT`") + else: + XLM_CHECKPOINT_PATH = sys.argv[2] + PYTORCH_DUMP_OUTPUT = sys.argv[3] + + convert_xlm_checkpoint_to_pytorch(XLM_CHECKPOINT_PATH, PYTORCH_DUMP_OUTPUT) if __name__ == '__main__': main() diff --git a/pytorch_transformers/convert_gpt2_checkpoint_to_pytorch.py b/pytorch_transformers/convert_gpt2_checkpoint_to_pytorch.py index 86c8264cb5..68cb798a7d 100755 --- a/pytorch_transformers/convert_gpt2_checkpoint_to_pytorch.py +++ b/pytorch_transformers/convert_gpt2_checkpoint_to_pytorch.py @@ -26,6 +26,9 @@ from pytorch_transformers.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME, GPT2Model, load_tf_weights_in_gpt2) +import logging +logging.basicConfig(level=logging.INFO) + def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): # Construct model @@ -36,7 +39,7 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p model = GPT2Model(config) # Load weights from numpy - load_tf_weights_in_gpt2(model, gpt2_checkpoint_path) + load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path) # Save pytorch-model pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME diff --git a/pytorch_transformers/convert_openai_checkpoint_to_pytorch.py b/pytorch_transformers/convert_openai_checkpoint_to_pytorch.py index 68e9dea624..8ec852a4bd 100755 --- a/pytorch_transformers/convert_openai_checkpoint_to_pytorch.py +++ b/pytorch_transformers/convert_openai_checkpoint_to_pytorch.py @@ -26,6 +26,9 @@ from pytorch_transformers.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME, OpenAIGPTModel, load_tf_weights_in_openai_gpt) +import logging +logging.basicConfig(level=logging.INFO) + def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): # Construct model @@ -36,7 +39,7 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c model = OpenAIGPTModel(config) # Load weights from numpy - load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path) + load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path) # Save pytorch-model pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME diff --git a/pytorch_transformers/convert_tf_checkpoint_to_pytorch.py b/pytorch_transformers/convert_tf_checkpoint_to_pytorch.py index 7530d7e12d..9f121e8b79 100755 --- a/pytorch_transformers/convert_tf_checkpoint_to_pytorch.py +++ b/pytorch_transformers/convert_tf_checkpoint_to_pytorch.py @@ -18,15 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os -import re import argparse -import tensorflow as tf import torch -import numpy as np from pytorch_transformers.modeling_bert import BertConfig, BertForPreTraining, load_tf_weights_in_bert +import logging +logging.basicConfig(level=logging.INFO) + def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): # Initialise PyTorch model config = BertConfig.from_json_file(bert_config_file) @@ -34,7 +33,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor model = BertForPreTraining(config) # Load weights from tf checkpoint - load_tf_weights_in_bert(model, tf_checkpoint_path) + load_tf_weights_in_bert(model, config, tf_checkpoint_path) # Save pytorch-model print("Save PyTorch model to {}".format(pytorch_dump_path)) diff --git a/pytorch_transformers/convert_transfo_xl_checkpoint_to_pytorch.py b/pytorch_transformers/convert_transfo_xl_checkpoint_to_pytorch.py index db23e5bffe..b6672aedf7 100755 --- a/pytorch_transformers/convert_transfo_xl_checkpoint_to_pytorch.py +++ b/pytorch_transformers/convert_transfo_xl_checkpoint_to_pytorch.py @@ -36,6 +36,9 @@ if sys.version_info[0] == 2: else: import pickle +import logging +logging.basicConfig(level=logging.INFO) + # We do this to be able to load python 2 datasets pickles # See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 data_utils.Vocab = data_utils.TransfoXLTokenizer diff --git a/pytorch_transformers/convert_xlm_checkpoint_to_pytorch.py b/pytorch_transformers/convert_xlm_checkpoint_to_pytorch.py index 416f1bc16d..8825f3c0dc 100755 --- a/pytorch_transformers/convert_xlm_checkpoint_to_pytorch.py +++ b/pytorch_transformers/convert_xlm_checkpoint_to_pytorch.py @@ -24,9 +24,10 @@ import torch import numpy from pytorch_transformers.modeling_utils import CONFIG_NAME, WEIGHTS_NAME -from pytorch_transformers.modeling_xlm import (XLMConfig, XLMModel) from pytorch_transformers.tokenization_xlm import VOCAB_FILES_NAMES +import logging +logging.basicConfig(level=logging.INFO) def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path): # Load checkpoint diff --git a/pytorch_transformers/convert_xlnet_checkpoint_to_pytorch.py b/pytorch_transformers/convert_xlnet_checkpoint_to_pytorch.py index f41db87124..834b47484f 100755 --- a/pytorch_transformers/convert_xlnet_checkpoint_to_pytorch.py +++ b/pytorch_transformers/convert_xlnet_checkpoint_to_pytorch.py @@ -40,6 +40,8 @@ GLUE_TASKS_NUM_LABELS = { "wnli": 2, } +import logging +logging.basicConfig(level=logging.INFO) def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None): # Initialise PyTorch model @@ -48,14 +50,17 @@ def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, py finetuning_task = finetuning_task.lower() if finetuning_task is not None else "" if finetuning_task in GLUE_TASKS_NUM_LABELS: print("Building PyTorch XLNetForSequenceClassification model from configuration: {}".format(str(config))) - model = XLNetForSequenceClassification(config, num_labels=GLUE_TASKS_NUM_LABELS[finetuning_task]) + config.finetuning_task = finetuning_task + config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task] + model = XLNetForSequenceClassification(config) elif 'squad' in finetuning_task: + config.finetuning_task = finetuning_task model = XLNetForQuestionAnswering(config) else: model = XLNetLMHeadModel(config) # Load weights from tf checkpoint - load_tf_weights_in_xlnet(model, config, tf_checkpoint_path, finetuning_task) + load_tf_weights_in_xlnet(model, config, tf_checkpoint_path) # Save pytorch-model pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) diff --git a/pytorch_transformers/modeling_xlnet.py b/pytorch_transformers/modeling_xlnet.py index c50d0a3f48..855bce7dfe 100644 --- a/pytorch_transformers/modeling_xlnet.py +++ b/pytorch_transformers/modeling_xlnet.py @@ -37,9 +37,11 @@ from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTra logger = logging.getLogger(__name__) XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = { + 'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-pytorch_model.bin", 'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-pytorch_model.bin", } XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = { + 'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json", 'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json", } diff --git a/pytorch_transformers/tokenization_transfo_xl.py b/pytorch_transformers/tokenization_transfo_xl.py index b08e8e1cca..9406d48c7b 100644 --- a/pytorch_transformers/tokenization_transfo_xl.py +++ b/pytorch_transformers/tokenization_transfo_xl.py @@ -50,7 +50,7 @@ PRETRAINED_VOCAB_FILES_MAP = { } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - 'transfo-xl-wt103': 512, + 'transfo-xl-wt103': None, } PRETRAINED_CORPUS_ARCHIVE_MAP = { diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index d857e6f2d4..df18f5e536 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -208,7 +208,8 @@ class PreTrainedTokenizer(object): # if we're using a pretrained model, ensure the tokenizer # wont index sequences longer than the number of positional embeddings max_len = cls.max_model_input_sizes[pretrained_model_name_or_path] - kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) + if max_len is not None and isinstance(max_len, (int, float)): + kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) # Merge resolved_vocab_files arguments in kwargs. added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None) diff --git a/pytorch_transformers/tokenization_xlnet.py b/pytorch_transformers/tokenization_xlnet.py index fa60a18d8a..a4f3fdfde2 100644 --- a/pytorch_transformers/tokenization_xlnet.py +++ b/pytorch_transformers/tokenization_xlnet.py @@ -32,12 +32,14 @@ VOCAB_FILES_NAMES = {'vocab_file': 'spiece.model'} PRETRAINED_VOCAB_FILES_MAP = { 'vocab_file': { + 'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-spiece.model", 'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-spiece.model", } } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - 'xlnet-large-cased': 512, + 'xlnet-base-cased': None, + 'xlnet-large-cased': None, } SPIECE_UNDERLINE = u'▁'