From da26bae61b8c1e741fdc6735d46c61b43f649561 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Thu, 10 Oct 2019 14:30:48 +0200 Subject: [PATCH] adding more tests on TF and pytorch serialization - updating configuration for better serialization --- transformers/__init__.py | 24 +++------ transformers/configuration_utils.py | 4 +- .../convert_pytorch_checkpoint_to_tf2.py | 54 ++++++++++--------- transformers/modeling_tf_bert.py | 10 ---- transformers/modeling_tf_ctrl.py | 10 ---- transformers/modeling_tf_distilbert.py | 10 ---- transformers/modeling_tf_gpt2.py | 10 ---- transformers/modeling_tf_openai.py | 10 ---- transformers/modeling_tf_pytorch_utils.py | 4 +- transformers/modeling_tf_roberta.py | 10 ---- transformers/modeling_tf_transfo_xl.py | 10 ---- transformers/modeling_tf_utils.py | 11 ++-- transformers/modeling_tf_xlm.py | 28 +++++----- transformers/modeling_tf_xlnet.py | 9 ---- transformers/tests/modeling_common_test.py | 34 ++++++++++++ 15 files changed, 90 insertions(+), 148 deletions(-) diff --git a/transformers/__init__.py b/transformers/__init__.py index 1edbf8b9a0..971b19b369 100644 --- a/transformers/__init__.py +++ b/transformers/__init__.py @@ -110,65 +110,55 @@ if is_tf_available(): TFBertForMaskedLM, TFBertForNextSentencePrediction, TFBertForSequenceClassification, TFBertForMultipleChoice, TFBertForTokenClassification, TFBertForQuestionAnswering, - load_bert_pt_weights_in_tf2, TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_tf_gpt2 import (TFGPT2PreTrainedModel, TFGPT2MainLayer, TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel, - load_gpt2_pt_weights_in_tf2, TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_tf_openai import (TFOpenAIGPTPreTrainedModel, TFOpenAIGPTMainLayer, TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel, TFOpenAIGPTDoubleHeadsModel, - load_openai_gpt_pt_weights_in_tf2, TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_tf_transfo_xl import (TFTransfoXLPreTrainedModel, TFTransfoXLMainLayer, TFTransfoXLModel, TFTransfoXLLMHeadModel, - load_transfo_xl_pt_weights_in_tf2, TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_tf_xlnet import (TFXLNetPreTrainedModel, TFXLNetMainLayer, TFXLNetModel, TFXLNetLMHeadModel, TFXLNetForSequenceClassification, TFXLNetForQuestionAnsweringSimple, - load_xlnet_pt_weights_in_tf2, TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_tf_xlm import (TFXLMPreTrainedModel, TFXLMMainLayer, TFXLMModel, TFXLMWithLMHeadModel, TFXLMForSequenceClassification, TFXLMForQuestionAnsweringSimple, - load_xlm_pt_weights_in_tf2, TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_tf_roberta import (TFRobertaPreTrainedModel, TFRobertaMainLayer, TFRobertaModel, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, - load_roberta_pt_weights_in_tf2, TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_tf_distilbert import (TFDistilBertPreTrainedModel, TFDistilBertMainLayer, TFDistilBertModel, TFDistilBertForMaskedLM, TFDistilBertForSequenceClassification, TFDistilBertForQuestionAnswering, - load_distilbert_pt_weights_in_tf2, TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_tf_ctrl import (TFCTRLPreTrainedModel, TFCTRLModel, TFCTRLLMHeadModel, - load_ctrl_pt_weights_in_tf2, TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP) # TF 2.0 <=> PyTorch conversion utilities -if is_tf_available() and is_torch_available(): - from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name, - load_pytorch_checkpoint_in_tf2_model, - load_pytorch_weights_in_tf2_model, - load_pytorch_model_in_tf2_model, - load_tf2_checkpoint_in_pytorch_model, - load_tf2_weights_in_pytorch_model, - load_tf2_model_in_pytorch_model) +from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name, + load_pytorch_checkpoint_in_tf2_model, + load_pytorch_weights_in_tf2_model, + load_pytorch_model_in_tf2_model, + load_tf2_checkpoint_in_pytorch_model, + load_tf2_weights_in_pytorch_model, + load_tf2_model_in_pytorch_model) if not is_tf_available() and not is_torch_available(): logger.warning("Neither PyTorch nor TensorFlow >= 2.0 have been found." diff --git a/transformers/configuration_utils.py b/transformers/configuration_utils.py index 112b15190f..9f79b85ef8 100644 --- a/transformers/configuration_utils.py +++ b/transformers/configuration_utils.py @@ -153,7 +153,7 @@ class PretrainedConfig(object): config = cls.from_json_file(resolved_config_file) if hasattr(config, 'pruned_heads'): - config.pruned_heads = dict((int(key), set(value)) for key, value in config.pruned_heads.items()) + config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items()) # Update config with kwargs if needed to_remove = [] @@ -164,7 +164,7 @@ class PretrainedConfig(object): for key in to_remove: kwargs.pop(key, None) - logger.info("Model config %s", config) + logger.info("Model config %s", str(config)) if return_unused_kwargs: return config, kwargs else: diff --git a/transformers/convert_pytorch_checkpoint_to_tf2.py b/transformers/convert_pytorch_checkpoint_to_tf2.py index 73878fc07d..e673b77dcc 100644 --- a/transformers/convert_pytorch_checkpoint_to_tf2.py +++ b/transformers/convert_pytorch_checkpoint_to_tf2.py @@ -24,15 +24,16 @@ import tensorflow as tf from transformers import is_torch_available, cached_path -from transformers import (BertConfig, TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification, load_bert_pt_weights_in_tf2, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, - GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, - XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, - XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, - TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, - OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, - RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, - DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, - CTRLConfig, TFCTRLLMHeadModel, load_ctrl_pt_weights_in_tf2, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP) +from transformers import (load_pytorch_checkpoint_in_tf2_model, + BertConfig, TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, + GPT2Config, TFGPT2LMHeadModel, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, + XLNetConfig, TFXLNetLMHeadModel, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, + XLMConfig, TFXLMWithLMHeadModel, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, + TransfoXLConfig, TFTransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, + OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, + RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, + DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, + CTRLConfig, TFCTRLLMHeadModel, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP) if is_torch_available(): import torch @@ -71,27 +72,27 @@ import logging logging.basicConfig(level=logging.INFO) MODEL_CLASSES = { - 'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP), - 'bert-large-uncased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, load_bert_pt_weights_in_tf2, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP), - 'bert-large-cased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, load_bert_pt_weights_in_tf2, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP), - 'bert-base-cased-finetuned-mrpc': (BertConfig, TFBertForSequenceClassification, load_bert_pt_weights_in_tf2, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP), - 'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP), - 'xlnet': (XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP), - 'xlm': (XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP), - 'transfo-xl': (TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP), - 'openai-gpt': (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP), - 'roberta': (RobertaConfig, TFRobertaForMaskedLM, load_roberta_pt_weights_in_tf2, RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP), - 'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP), - 'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, load_distilbert_pt_weights_in_tf2, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP), - 'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP), - 'ctrl': (CTRLConfig, TFCTRLLMHeadModel, load_ctrl_pt_weights_in_tf2, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP) + 'bert': (BertConfig, TFBertForPreTraining, BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP), + 'bert-large-uncased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP), + 'bert-large-cased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP), + 'bert-base-cased-finetuned-mrpc': (BertConfig, TFBertForSequenceClassification, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP), + 'gpt2': (GPT2Config, TFGPT2LMHeadModel, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP), + 'xlnet': (XLNetConfig, TFXLNetLMHeadModel, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP), + 'xlm': (XLMConfig, TFXLMWithLMHeadModel, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP), + 'transfo-xl': (TransfoXLConfig, TFTransfoXLLMHeadModel, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP), + 'openai-gpt': (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP), + 'roberta': (RobertaConfig, TFRobertaForMaskedLM, RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP), + 'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP), + 'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP), + 'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP), + 'ctrl': (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP) } def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True): if model_type not in MODEL_CLASSES: raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys()))) - config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type] + config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type] # Initialise TF model if config_file in aws_config_map: @@ -105,7 +106,8 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file # Load weights from tf checkpoint if pytorch_checkpoint_path in aws_model_maps: pytorch_checkpoint_path = cached_path(aws_model_maps[pytorch_checkpoint_path], force_download=not use_cached_models) - tf_model = loading_fct(tf_model, pytorch_checkpoint_path) + # Load PyTorch checkpoint in tf2 model: + tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path) if compare_with_pt_model: inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]] @@ -147,7 +149,7 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc if model_type not in MODEL_CLASSES: raise ValueError("Unrecognized model type {}, should be one of {}.".format(model_type, list(MODEL_CLASSES.keys()))) - config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type] + config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type] if model_shortcut_names_or_path is None: model_shortcut_names_or_path = list(aws_model_maps.keys()) diff --git a/transformers/modeling_tf_bert.py b/transformers/modeling_tf_bert.py index 4de94751f8..afe9b2946b 100644 --- a/transformers/modeling_tf_bert.py +++ b/transformers/modeling_tf_bert.py @@ -30,7 +30,6 @@ import tensorflow as tf from .configuration_bert import BertConfig from .modeling_tf_utils import TFPreTrainedModel, get_initializer from .file_utils import add_start_docstrings -from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model logger = logging.getLogger(__name__) @@ -52,14 +51,6 @@ TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP = { } -def load_bert_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path): - # build the network - inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]] - tf_inputs = tf.constant(inputs_list) - tfo = tf_model(tf_inputs, training=False) - return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs) - - def gelu(x): """ Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when initially created. @@ -545,7 +536,6 @@ class TFBertPreTrainedModel(TFPreTrainedModel): """ config_class = BertConfig pretrained_model_archive_map = TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP - load_pt_weights = load_bert_pt_weights_in_tf2 base_model_prefix = "bert" diff --git a/transformers/modeling_tf_ctrl.py b/transformers/modeling_tf_ctrl.py index 62f5d3cef4..95cc873448 100644 --- a/transformers/modeling_tf_ctrl.py +++ b/transformers/modeling_tf_ctrl.py @@ -27,20 +27,11 @@ import tensorflow as tf from .configuration_ctrl import CTRLConfig from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list, TFSharedEmbeddings from .file_utils import add_start_docstrings -from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model logger = logging.getLogger(__name__) TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP = {"ctrl": "https://s3.amazonaws.com/models.huggingface.co/bert/ctrl-tf_model.h5"} -def load_ctrl_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path): - # build the network - inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]] - tf_inputs = tf.constant(inputs_list) - tfo = tf_model(tf_inputs, training=False) - return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs) - - def angle_defn(pos, i, d_model_size): angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model_size)) return pos * angle_rates @@ -327,7 +318,6 @@ class TFCTRLPreTrainedModel(TFPreTrainedModel): config_class = CTRLConfig pretrained_model_archive_map = TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP base_model_prefix = "transformer" - load_pt_weights = load_ctrl_pt_weights_in_tf2 CTRL_START_DOCSTRING = r""" CTRL model was proposed in diff --git a/transformers/modeling_tf_distilbert.py b/transformers/modeling_tf_distilbert.py index f9fe4ca9e9..188394816e 100644 --- a/transformers/modeling_tf_distilbert.py +++ b/transformers/modeling_tf_distilbert.py @@ -31,7 +31,6 @@ import tensorflow as tf from .configuration_distilbert import DistilBertConfig from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list, get_initializer from .file_utils import add_start_docstrings -from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model logger = logging.getLogger(__name__) @@ -66,14 +65,6 @@ def gelu_new(x): (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))) return x * cdf -def load_distilbert_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path): - # build the network - inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]) - attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]) - tf_inputs = [inputs_list, attns_list] - tfo = tf_model(tf_inputs, training=False) - return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs) - class TFEmbeddings(tf.keras.layers.Layer): def __init__(self, config, **kwargs): super(TFEmbeddings, self).__init__(**kwargs) @@ -454,7 +445,6 @@ class TFDistilBertPreTrainedModel(TFPreTrainedModel): """ config_class = DistilBertConfig pretrained_model_archive_map = TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP - load_pt_weights = load_distilbert_pt_weights_in_tf2 base_model_prefix = "distilbert" diff --git a/transformers/modeling_tf_gpt2.py b/transformers/modeling_tf_gpt2.py index 883340cac9..4188b273ba 100644 --- a/transformers/modeling_tf_gpt2.py +++ b/transformers/modeling_tf_gpt2.py @@ -32,7 +32,6 @@ from .modeling_tf_utils import (TFPreTrainedModel, TFConv1D, TFSharedEmbeddings, TFSequenceSummary, shape_list, get_initializer) from .configuration_gpt2 import GPT2Config from .file_utils import add_start_docstrings -from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model logger = logging.getLogger(__name__) @@ -42,14 +41,6 @@ TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-tf_model.h5",} -def load_gpt2_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path): - # build the network - inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]] - tf_inputs = tf.constant(inputs_list) - tfo = tf_model(tf_inputs, training=False) - return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs) - - def gelu(x): """Gaussian Error Linear Unit. This is a smoother version of the RELU. @@ -350,7 +341,6 @@ class TFGPT2PreTrainedModel(TFPreTrainedModel): """ config_class = GPT2Config pretrained_model_archive_map = TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP - load_pt_weights = load_gpt2_pt_weights_in_tf2 base_model_prefix = "transformer" diff --git a/transformers/modeling_tf_openai.py b/transformers/modeling_tf_openai.py index 7521866c24..747c5171fd 100644 --- a/transformers/modeling_tf_openai.py +++ b/transformers/modeling_tf_openai.py @@ -32,21 +32,12 @@ from .modeling_tf_utils import (TFPreTrainedModel, TFConv1D, TFSharedEmbeddings, TFSequenceSummary, shape_list, get_initializer) from .configuration_openai import OpenAIGPTConfig from .file_utils import add_start_docstrings -from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model logger = logging.getLogger(__name__) TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-tf_model.h5"} -def load_openai_gpt_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path): - # build the network - inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]] - tf_inputs = tf.constant(inputs_list) - tfo = tf_model(tf_inputs, training=False) - return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs) - - def gelu(x): """Gaussian Error Linear Unit. This is a smoother version of the RELU. @@ -335,7 +326,6 @@ class TFOpenAIGPTPreTrainedModel(TFPreTrainedModel): """ config_class = OpenAIGPTConfig pretrained_model_archive_map = TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP - load_pt_weights = load_openai_gpt_pt_weights_in_tf2 base_model_prefix = "transformer" diff --git a/transformers/modeling_tf_pytorch_utils.py b/transformers/modeling_tf_pytorch_utils.py index 66caa95ec7..5a70d9a72b 100644 --- a/transformers/modeling_tf_pytorch_utils.py +++ b/transformers/modeling_tf_pytorch_utils.py @@ -25,8 +25,6 @@ import numpy logger = logging.getLogger(__name__) -DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]] - def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove=''): """ Convert a TF 2.0 model variable name in a pytorch model weight name. @@ -105,7 +103,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a raise e if tf_inputs is None: - tf_inputs = tf.constant(DUMMY_INPUTS) + tf_inputs = tf_model.dummy_inputs if tf_inputs is not None: tfo = tf_model(tf_inputs, training=False) # Make sure model is built diff --git a/transformers/modeling_tf_roberta.py b/transformers/modeling_tf_roberta.py index 43747133ff..db62dd3014 100644 --- a/transformers/modeling_tf_roberta.py +++ b/transformers/modeling_tf_roberta.py @@ -26,7 +26,6 @@ import tensorflow as tf from .configuration_roberta import RobertaConfig from .modeling_tf_utils import TFPreTrainedModel, get_initializer from .file_utils import add_start_docstrings -from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model from .modeling_tf_bert import TFBertEmbeddings, TFBertMainLayer, gelu, gelu_new @@ -38,14 +37,6 @@ TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = { 'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-tf_model.h5", } -def load_roberta_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path): - # build the network - inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]] - tf_inputs = tf.constant(inputs_list) - tfo = tf_model(tf_inputs, training=False) - return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs) - - class TFRobertaEmbeddings(TFBertEmbeddings): """ Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. @@ -96,7 +87,6 @@ class TFRobertaPreTrainedModel(TFPreTrainedModel): """ config_class = RobertaConfig pretrained_model_archive_map = TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP - load_pt_weights = load_roberta_pt_weights_in_tf2 base_model_prefix = "roberta" diff --git a/transformers/modeling_tf_transfo_xl.py b/transformers/modeling_tf_transfo_xl.py index df8c7e7dc9..a3e403ce06 100644 --- a/transformers/modeling_tf_transfo_xl.py +++ b/transformers/modeling_tf_transfo_xl.py @@ -33,7 +33,6 @@ from .configuration_transfo_xl import TransfoXLConfig from .modeling_tf_utils import TFPreTrainedModel, TFConv1D, TFSequenceSummary, shape_list, get_initializer from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask from .file_utils import add_start_docstrings -from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model logger = logging.getLogger(__name__) @@ -41,14 +40,6 @@ TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = { 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-tf_model.h5", } -def load_transfo_xl_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path): - # build the network - inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]] - tf_inputs = tf.constant(inputs_list) - tfo = tf_model(tf_inputs, training=False) - return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs) - - class TFPositionalEmbedding(tf.keras.layers.Layer): def __init__(self, demb, **kwargs): super(TFPositionalEmbedding, self).__init__(**kwargs) @@ -577,7 +568,6 @@ class TFTransfoXLPreTrainedModel(TFPreTrainedModel): """ config_class = TransfoXLConfig pretrained_model_archive_map = TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP - load_pt_weights = load_transfo_xl_pt_weights_in_tf2 base_model_prefix = "transformer" diff --git a/transformers/modeling_tf_utils.py b/transformers/modeling_tf_utils.py index 06a333af37..3a576345f5 100644 --- a/transformers/modeling_tf_utils.py +++ b/transformers/modeling_tf_utils.py @@ -25,9 +25,11 @@ import tensorflow as tf from .configuration_utils import PretrainedConfig from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME +from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model logger = logging.getLogger(__name__) +DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]] class TFPreTrainedModel(tf.keras.Model): r""" Base class for all TF models. @@ -48,8 +50,8 @@ class TFPreTrainedModel(tf.keras.Model): """ config_class = None pretrained_model_archive_map = {} - load_pt_weights = lambda model, config, path: None base_model_prefix = "" + dummy_inputs = tf.constant(DUMMY_INPUTS) # dummy inputs to build the network def __init__(self, config, *inputs, **kwargs): super(TFPreTrainedModel, self).__init__(*inputs, **kwargs) @@ -262,17 +264,16 @@ class TFPreTrainedModel(tf.keras.Model): if from_pt: # Load from a PyTorch checkpoint - return cls.load_pt_weights(model, resolved_archive_file) + return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file) - inputs = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]) - ret = model(inputs, training=False) # build the network with dummy inputs + ret = model(model.dummy_inputs, training=False) # build the network with dummy inputs assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file) # 'by_name' allow us to do transfer learning by skipping/adding layers # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357 model.load_weights(resolved_archive_file, by_name=True) - ret = model(inputs, training=False) # Make sure restore ops are run + ret = model(model.dummy_inputs, training=False) # Make sure restore ops are run return model diff --git a/transformers/modeling_tf_xlm.py b/transformers/modeling_tf_xlm.py index 83cc37c6a7..84de1517ee 100644 --- a/transformers/modeling_tf_xlm.py +++ b/transformers/modeling_tf_xlm.py @@ -25,9 +25,8 @@ import numpy as np import tensorflow as tf from .configuration_xlm import XLMConfig -from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list, get_initializer +from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list, get_initializer, DUMMY_INPUTS from .file_utils import add_start_docstrings -from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model logger = logging.getLogger(__name__) @@ -45,19 +44,6 @@ TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP = { } -def load_xlm_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path): - # build the network - inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]) - attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]) - if tf_model.config.use_lang_emb and tf_model.config.n_langs > 1: - langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]) - else: - langs_list = None - tf_inputs = [inputs_list, attns_list, langs_list] - tfo = tf_model(tf_inputs, training=False) - return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs) - - def create_sinusoidal_embeddings(n_pos, dim, out): position_enc = np.array([ [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] @@ -441,9 +427,19 @@ class TFXLMPreTrainedModel(TFPreTrainedModel): """ config_class = XLMConfig pretrained_model_archive_map = TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP - load_pt_weights = load_xlm_pt_weights_in_tf2 base_model_prefix = "transformer" + @property + def dummy_inputs(self): + # Sometimes XLM has language embeddings so don't forget to build them as well if needed + inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]) + attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]) + if self.config.use_lang_emb and self.config.n_langs > 1: + langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]) + else: + langs_list = None + return [inputs_list, attns_list, langs_list] + XLM_START_DOCSTRING = r""" The XLM model was proposed in `Cross-lingual Language Model Pretraining`_ diff --git a/transformers/modeling_tf_xlnet.py b/transformers/modeling_tf_xlnet.py index 9370bd0915..904c2f4af0 100644 --- a/transformers/modeling_tf_xlnet.py +++ b/transformers/modeling_tf_xlnet.py @@ -30,7 +30,6 @@ import tensorflow as tf from .configuration_xlnet import XLNetConfig from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list, get_initializer from .file_utils import add_start_docstrings -from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model logger = logging.getLogger(__name__) @@ -41,13 +40,6 @@ TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = { } -def load_xlnet_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path): - inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]] - tf_inputs = tf.constant(inputs_list) - tfo = tf_model(tf_inputs, training=False) # build the network - return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs) - - def gelu(x): """ Implementation of the gelu activation function. XLNet is using OpenAI GPT's gelu @@ -670,7 +662,6 @@ class TFXLNetPreTrainedModel(TFPreTrainedModel): """ config_class = XLNetConfig pretrained_model_archive_map = TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP - load_pt_weights = load_xlnet_pt_weights_in_tf2 base_model_prefix = "transformer" diff --git a/transformers/tests/modeling_common_test.py b/transformers/tests/modeling_common_test.py index 2b66757c28..298dcf3bdc 100644 --- a/transformers/tests/modeling_common_test.py +++ b/transformers/tests/modeling_common_test.py @@ -17,8 +17,10 @@ from __future__ import division from __future__ import print_function import copy +import sys import os import shutil +import tempfile import json import random import uuid @@ -31,6 +33,7 @@ from transformers import is_torch_available if is_torch_available(): import torch + import numpy as np from transformers import (PretrainedConfig, PreTrainedModel, BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, @@ -38,6 +41,20 @@ if is_torch_available(): else: pytestmark = pytest.mark.skip("Require Torch") +if sys.version_info[0] == 2: + import cPickle as pickle + + class TemporaryDirectory(object): + """Context manager for tempfile.mkdtemp() so it's usable with "with" statement.""" + def __enter__(self): + self.name = tempfile.mkdtemp() + return self.name + def __exit__(self, exc_type, exc_value, traceback): + shutil.rmtree(self.name) +else: + import pickle + TemporaryDirectory = tempfile.TemporaryDirectory + unicode = str def _config_zero_init(config): configs_no_init = copy.deepcopy(config) @@ -57,6 +74,23 @@ class CommonTestCases: test_resize_embeddings = True test_head_masking = True + def test_save_load(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.eval() + with torch.no_grad(): + outputs = model(**inputs_dict) + + with TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname) + with torch.no_grad(): + after_outputs = model(**inputs_dict) + max_diff = np.amax(np.abs(after_outputs[0].numpy() - outputs[0].numpy())) + self.assertLessEqual(max_diff, 1e-5) + def test_initialization(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()