From a7e01a248b4232fa1bfa62c3a0d86d3d09efb281 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 24 Sep 2019 10:58:52 +0200 Subject: [PATCH] converting distilled/fine-tuned models --- pytorch_transformers/__init__.py | 2 +- .../convert_pytorch_checkpoint_to_tf2.py | 103 ++++++++++++------ .../modeling_tf_distilbert.py | 6 +- .../modeling_tf_pytorch_utils.py | 18 ++- 4 files changed, 90 insertions(+), 39 deletions(-) diff --git a/pytorch_transformers/__init__.py b/pytorch_transformers/__init__.py index e12f4b2650..907115f70d 100644 --- a/pytorch_transformers/__init__.py +++ b/pytorch_transformers/__init__.py @@ -146,7 +146,7 @@ if _tf_available: from .modeling_tf_distilbert import (TFDistilBertPreTrainedModel, TFDistilBertMainLayer, TFDistilBertModel, TFDistilBertForMaskedLM, TFDistilBertForSequenceClassification, - TFDistilBertForSequenceClassification, + TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP) diff --git a/pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py b/pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py index 567f1d0b5b..6c0043d6a7 100644 --- a/pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py +++ b/pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py @@ -24,43 +24,43 @@ import tensorflow as tf from pytorch_transformers import is_torch_available, cached_path -from pytorch_transformers import (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, +from pytorch_transformers import (BertConfig, TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification, load_bert_pt_weights_in_tf2, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, - RobertaConfig, TFRobertaForMaskedLM, load_roberta_pt_weights_in_tf2, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, - DistilBertConfig, TFDistilBertForMaskedLM, load_distilbert_pt_weights_in_tf2, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP) + RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, + DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP) if is_torch_available(): import torch import numpy as np - from pytorch_transformers import (BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, + from pytorch_transformers import (BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, - RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, - DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP) + RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, + DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP) else: - (BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, + (BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, - RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, - DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,) = ( + RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, + DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,) = ( + None, None, None, None, None, None, None, None, None, None, None, None, None, None, - None, None, - None, None, - None, None,) + None, None, None, + None, None, None,) import logging @@ -68,22 +68,29 @@ logging.basicConfig(level=logging.INFO) MODEL_CLASSES = { 'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP), + 'bert-large-uncased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, load_bert_pt_weights_in_tf2, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP), + 'bert-large-cased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, load_bert_pt_weights_in_tf2, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP), + 'bert-base-cased-finetuned-mrpc': (BertConfig, TFBertForSequenceClassification, load_bert_pt_weights_in_tf2, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP), 'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP), 'xlnet': (XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP), 'xlm': (XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP), 'transfo-xl': (TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP), 'openai-gpt': (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP), 'roberta': (RobertaConfig, TFRobertaForMaskedLM, load_roberta_pt_weights_in_tf2, RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP), + 'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP), 'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, load_distilbert_pt_weights_in_tf2, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP), + 'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP), } -def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False): +def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True): if model_type not in MODEL_CLASSES: raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys()))) config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type] # Initialise TF model + if config_file in aws_config_map: + config_file = cached_path(aws_config_map[config_file], force_download=not use_cached_models) config = config_class.from_json_file(config_file) config.output_hidden_states = True config.output_attentions = True @@ -91,6 +98,8 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file tf_model = model_class(config) # Load weights from tf checkpoint + if pytorch_checkpoint_path in aws_model_maps: + pytorch_checkpoint_path = cached_path(aws_model_maps[pytorch_checkpoint_path], force_download=not use_cached_models) tf_model = loading_fct(tf_model, pytorch_checkpoint_path) if compare_with_pt_model: @@ -117,7 +126,8 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file tf_model.save_weights(tf_dump_path, save_format='h5') -def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, compare_with_pt_model=False, use_cached_models=False): +def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortcut_names_or_path=None, config_shortcut_names_or_path=None, + compare_with_pt_model=False, use_cached_models=False, only_convert_finetuned_models=False): assert os.path.isdir(args.tf_dump_path), "--tf_dump_path should be a directory" if args_model_type is None: @@ -134,20 +144,39 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, compare_with config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type] - for i, shortcut_name in enumerate(aws_config_map.keys(), start=1): + if model_shortcut_names_or_path is None: + model_shortcut_names_or_path = list(aws_model_maps.keys()) + if config_shortcut_names_or_path is None: + config_shortcut_names_or_path = model_shortcut_names_or_path + + for i, (model_shortcut_name, config_shortcut_name) in enumerate( + zip(model_shortcut_names_or_path, config_shortcut_names_or_path), start=1): print("-" * 100) - print(" Converting checkpoint {}/{}: {}".format(i, len(aws_config_map), shortcut_name)) - print("-" * 100) - if 'finetuned' in shortcut_name: - print(" Skipping finetuned checkpoint ") + if '-squad' in model_shortcut_name or '-mrpc' in model_shortcut_name or '-mnli' in model_shortcut_name: + if not only_convert_finetuned_models: + print(" Skipping finetuned checkpoint {}".format(model_shortcut_name)) + continue + model_type = model_shortcut_name + elif only_convert_finetuned_models: + print(" Skipping not finetuned checkpoint {}".format(model_shortcut_name)) continue - config_file = cached_path(aws_config_map[shortcut_name], force_download=not use_cached_models) - model_file = cached_path(aws_model_maps[shortcut_name], force_download=not use_cached_models) + print(" Converting checkpoint {}/{}: {} - model_type {}".format(i, len(aws_config_map), model_shortcut_name, model_type)) + print("-" * 100) + + if config_shortcut_name in aws_config_map: + config_file = cached_path(aws_config_map[config_shortcut_name], force_download=not use_cached_models) + else: + config_file = cached_path(config_shortcut_name, force_download=not use_cached_models) + + if model_shortcut_name in aws_model_maps: + model_file = cached_path(aws_model_maps[model_shortcut_name], force_download=not use_cached_models) + else: + model_file = cached_path(model_shortcut_name, force_download=not use_cached_models) convert_pt_checkpoint_to_tf(model_type, model_file, config_file, - os.path.join(tf_dump_path, shortcut_name + '-tf_model.h5'), + os.path.join(tf_dump_path, model_shortcut_name + '-tf_model.h5'), compare_with_pt_model=compare_with_pt_model) os.remove(config_file) os.remove(model_file) @@ -176,23 +205,29 @@ if __name__ == "__main__": help = "The config json file corresponding to the pre-trained model. \n" "This specifies the model architecture. If not given and " "--pytorch_checkpoint_path is not given or is a shortcut name" - "use the configuration associated to teh shortcut name on the AWS") + "use the configuration associated to the shortcut name on the AWS") parser.add_argument("--compare_with_pt_model", action='store_true', help = "Compare Tensorflow and PyTorch model predictions.") parser.add_argument("--use_cached_models", action='store_true', help = "Use cached models if possible instead of updating to latest checkpoint versions.") + parser.add_argument("--only_convert_finetuned_models", + action='store_true', + help = "Only convert finetuned models.") args = parser.parse_args() - if args.pytorch_checkpoint_path is not None: - convert_pt_checkpoint_to_tf(args.model_type.lower(), - args.pytorch_checkpoint_path, - args.config_file, - args.tf_dump_path, - compare_with_pt_model=args.compare_with_pt_model) - else: - convert_all_pt_checkpoints_to_tf(args.model_type.lower() if args.model_type is not None else None, - args.tf_dump_path, - compare_with_pt_model=args.compare_with_pt_model, - use_cached_models=args.use_cached_models) + # if args.pytorch_checkpoint_path is not None: + # convert_pt_checkpoint_to_tf(args.model_type.lower(), + # args.pytorch_checkpoint_path, + # args.config_file if args.config_file is not None else args.pytorch_checkpoint_path, + # args.tf_dump_path, + # compare_with_pt_model=args.compare_with_pt_model, + # use_cached_models=args.use_cached_models) + # else: + convert_all_pt_checkpoints_to_tf(args.model_type.lower() if args.model_type is not None else None, + args.tf_dump_path, + model_shortcut_names_or_path=[args.pytorch_checkpoint_path] if args.pytorch_checkpoint_path is not None else None, + compare_with_pt_model=args.compare_with_pt_model, + use_cached_models=args.use_cached_models, + only_convert_finetuned_models=args.only_convert_finetuned_models) diff --git a/pytorch_transformers/modeling_tf_distilbert.py b/pytorch_transformers/modeling_tf_distilbert.py index 706d2fc02d..1811573bbf 100644 --- a/pytorch_transformers/modeling_tf_distilbert.py +++ b/pytorch_transformers/modeling_tf_distilbert.py @@ -653,7 +653,7 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel): super(TFDistilBertForSequenceClassification, self).__init__(config, *inputs, **kwargs) self.num_labels = config.num_labels - self.distilbert = TFDistilBertModel(config, name="distilbert") + self.distilbert = TFDistilBertMainLayer(config, name="distilbert") self.pre_classifier = tf.keras.layers.Dense(config.dim, activation='relu', name="pre_classifier") self.classifier = tf.keras.layers.Dense(config.num_labels, name="classifier") self.dropout = tf.keras.layers.Dropout(config.seq_classif_dropout) @@ -714,8 +714,8 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel): def __init__(self, config, *inputs, **kwargs): super(TFDistilBertForQuestionAnswering, self).__init__(config, *inputs, **kwargs) - self.distilbert = TFDistilBertModel(config, name="distilbert") - self.qa_outputs = tf.keras.layers.Dense(config.num_labels, name='qa_output') + self.distilbert = TFDistilBertMainLayer(config, name="distilbert") + self.qa_outputs = tf.keras.layers.Dense(config.num_labels, name='qa_outputs') assert config.num_labels == 2 self.dropout = tf.keras.layers.Dropout(config.qa_dropout) diff --git a/pytorch_transformers/modeling_tf_pytorch_utils.py b/pytorch_transformers/modeling_tf_pytorch_utils.py index 9950a5a73f..12e5023802 100644 --- a/pytorch_transformers/modeling_tf_pytorch_utils.py +++ b/pytorch_transformers/modeling_tf_pytorch_utils.py @@ -148,8 +148,24 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path): """ Load TF 2.0 HDF5 checkpoint in a PyTorch model We use HDF5 to easily do transfer learning (see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357). + Conventions for TF2.0 scopes -> PyTorch attribute names conversions: + - '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch) + - '_._' is replaced by a new level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList) """ - raise NotImplementedError + try: + import tensorflow as tf + import torch + except ImportError as e: + logger.error("Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see " + "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.") + raise e + + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info("Loading TensorFlow weights from {}".format(tf_path)) + + tf_state_dict = torch.load(tf_path, map_location='cpu') + + return load_tf2_weights_in_pytorch_model(pt_model, tf_state_dict) def load_tf2_weights_in_pytorch_model(pt_model, tf_model): """ Load TF2.0 symbolic weights in a PyTorch model