From c37815f1300519e1a812e1080c46641db6f9f604 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 20 Dec 2019 14:35:40 +0100 Subject: [PATCH] clean up PT <=> TF 2.0 conversion and config loading --- .../convert_pytorch_checkpoint_to_tf2.py | 9 +++++---- transformers/modeling_tf_utils.py | 17 ++++++++++++----- transformers/modeling_utils.py | 17 ++++++++++++----- 3 files changed, 29 insertions(+), 14 deletions(-) diff --git a/transformers/convert_pytorch_checkpoint_to_tf2.py b/transformers/convert_pytorch_checkpoint_to_tf2.py index 4a9832f123..0edac6fb7d 100644 --- a/transformers/convert_pytorch_checkpoint_to_tf2.py +++ b/transformers/convert_pytorch_checkpoint_to_tf2.py @@ -32,7 +32,7 @@ from transformers import (load_pytorch_checkpoint_in_tf2_model, TransfoXLConfig, TFTransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, - DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, + DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, TFDistilBertForSequenceClassification, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig, TFCTRLLMHeadModel, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig, TFAlbertForMaskedLM, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config, TFT5WithLMHeadModel, T5_PRETRAINED_CONFIG_ARCHIVE_MAP) @@ -47,7 +47,7 @@ if is_torch_available(): TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, - DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, + DistilBertForMaskedLM, DistilBertForQuestionAnswering, DistilBertForSequenceClassification, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP, T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP) @@ -59,7 +59,7 @@ else: TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, - DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, + DistilBertForMaskedLM, DistilBertForSequenceClassification, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP, T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP) = ( @@ -70,7 +70,7 @@ else: None, None, None, None, None, None, None, - None, None, None, + None, None, None, None, None, None, None, None, None, None) @@ -93,6 +93,7 @@ MODEL_CLASSES = { 'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP), 'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP), 'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP), + 'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP), 'ctrl': (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP), 'albert': (AlbertConfig, TFAlbertForMaskedLM, AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP), 't5': (T5Config, TFT5WithLMHeadModel, T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5_PRETRAINED_CONFIG_ARCHIVE_MAP), diff --git a/transformers/modeling_tf_utils.py b/transformers/modeling_tf_utils.py index 401ffeb67e..0aa65a9f17 100644 --- a/transformers/modeling_tf_utils.py +++ b/transformers/modeling_tf_utils.py @@ -184,7 +184,9 @@ class TFPreTrainedModel(tf.keras.Model): model_args: (`optional`) Sequence of positional arguments: All remaning positional arguments will be passed to the underlying model's ``__init__`` method - config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`: + config: (`optional`) one of: + - an instance of a class derived from :class:`~transformers.PretrainedConfig`, or + - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()` Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when: - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or @@ -236,10 +238,11 @@ class TFPreTrainedModel(tf.keras.Model): proxies = kwargs.pop('proxies', None) output_loading_info = kwargs.pop('output_loading_info', False) - # Load config - if config is None: + # Load config if we don't provide a configuration + if not isinstance(config, PretrainedConfig): + config_path = config if config is not None else pretrained_model_name_or_path config, model_kwargs = cls.config_class.from_pretrained( - pretrained_model_name_or_path, *model_args, + config_path, *model_args, cache_dir=cache_dir, return_unused_kwargs=True, force_download=force_download, resume_download=resume_download, @@ -310,7 +313,11 @@ class TFPreTrainedModel(tf.keras.Model): assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file) # 'by_name' allow us to do transfer learning by skipping/adding layers # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357 - model.load_weights(resolved_archive_file, by_name=True) + try: + model.load_weights(resolved_archive_file, by_name=True) + except OSError: + raise OSError("Unable to load weights from h5 file. " + "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. ") ret = model(model.dummy_inputs, training=False) # Make sure restore ops are run diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py index eff54f71e1..3bc407e4a3 100644 --- a/transformers/modeling_utils.py +++ b/transformers/modeling_utils.py @@ -281,7 +281,9 @@ class PreTrainedModel(nn.Module): model_args: (`optional`) Sequence of positional arguments: All remaning positional arguments will be passed to the underlying model's ``__init__`` method - config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`: + config: (`optional`) one of: + - an instance of a class derived from :class:`~transformers.PretrainedConfig`, or + - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()` Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when: - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or @@ -336,10 +338,11 @@ class PreTrainedModel(nn.Module): proxies = kwargs.pop('proxies', None) output_loading_info = kwargs.pop('output_loading_info', False) - # Load config - if config is None: + # Load config if we don't provide a configuration + if not isinstance(config, PretrainedConfig): + config_path = config if config is not None else pretrained_model_name_or_path config, model_kwargs = cls.config_class.from_pretrained( - pretrained_model_name_or_path, *model_args, + config_path, *model_args, cache_dir=cache_dir, return_unused_kwargs=True, force_download=force_download, resume_download=resume_download, @@ -408,7 +411,11 @@ class PreTrainedModel(nn.Module): model = cls(config, *model_args, **model_kwargs) if state_dict is None and not from_tf: - state_dict = torch.load(resolved_archive_file, map_location='cpu') + try: + state_dict = torch.load(resolved_archive_file, map_location='cpu') + except: + raise OSError("Unable to load weights from pytorch checkpoint file. " + "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. ") missing_keys = [] unexpected_keys = []