From 9cf7b23b9bda5ae0e827993e8154d17065ef8dab Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Mon, 5 Oct 2020 15:58:45 +0200 Subject: [PATCH] Custom TF weights loading (#7422) * First try * Fix TF utils * Handle authorized unexpected keys when loading weights * Add several more authorized unexpected keys * Apply style * Fix test * Address Patrick's comments. * Update src/transformers/modeling_tf_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_tf_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Apply style * Make return_dict the default behavior and display a warning message * Revert * Replace wrong keyword * Revert code * Add forgot key * Fix bug in loading PT models from a TF one. * Fix sort * Add a test for custom load weights in BERT * Apply style * Remove unused import Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/modeling_tf_bert.py | 4 + src/transformers/modeling_tf_pytorch_utils.py | 18 ++- src/transformers/modeling_tf_utils.py | 127 +++++++++++++++--- tests/test_modeling_tf_bert.py | 17 ++- 4 files changed, 139 insertions(+), 27 deletions(-) diff --git a/src/transformers/modeling_tf_bert.py b/src/transformers/modeling_tf_bert.py index eb8c1ad35b..5b7a333f47 100644 --- a/src/transformers/modeling_tf_bert.py +++ b/src/transformers/modeling_tf_bert.py @@ -854,6 +854,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel): @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING) class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss): + authorized_unexpected_keys = [r"pooler"] authorized_missing_keys = [r"pooler"] def __init__(self, config, *inputs, **kwargs): @@ -939,6 +940,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss): class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss): + authorized_unexpected_keys = [r"pooler"] authorized_missing_keys = [r"pooler"] def __init__(self, config, *inputs, **kwargs): @@ -1286,6 +1288,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss): ) class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss): + authorized_unexpected_keys = [r"pooler"] authorized_missing_keys = [r"pooler"] def __init__(self, config, *inputs, **kwargs): @@ -1369,6 +1372,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL ) class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss): + authorized_unexpected_keys = [r"pooler"] authorized_missing_keys = [r"pooler"] def __init__(self, config, *inputs, **kwargs): diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index ca1bf4813e..713b426d3e 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -177,6 +177,13 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a elif len(symbolic_weight.shape) > len(array.shape): array = numpy.expand_dims(array, axis=0) + if list(symbolic_weight.shape) != list(array.shape): + try: + array = numpy.reshape(array, symbolic_weight.shape) + except AssertionError as e: + e.args += (symbolic_weight.shape, array.shape) + raise e + try: assert list(symbolic_weight.shape) == list(array.shape) except AssertionError as e: @@ -251,6 +258,8 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs import transformers + from .modeling_tf_utils import load_tf_weights + logger.info("Loading TensorFlow weights from {}".format(tf_checkpoint_path)) # Instantiate and load the associated TF 2.0 model @@ -264,7 +273,7 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs if tf_inputs is not None: tf_model(tf_inputs, training=False) # Make sure model is built - tf_model.load_weights(tf_checkpoint_path, by_name=True) + load_tf_weights(tf_model, tf_checkpoint_path) return load_tf2_model_in_pytorch_model(pt_model, tf_model, allow_missing_keys=allow_missing_keys) @@ -332,6 +341,13 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F elif len(pt_weight.shape) > len(array.shape): array = numpy.expand_dims(array, axis=0) + if list(pt_weight.shape) != list(array.shape): + try: + array = numpy.reshape(array, pt_weight.shape) + except AssertionError as e: + e.args += (pt_weight.shape, array.shape) + raise e + try: assert list(pt_weight.shape) == list(array.shape) except AssertionError as e: diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 89462bcddf..11f361590a 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -23,12 +23,12 @@ from typing import Dict, List, Optional, Union import h5py import numpy as np import tensorflow as tf +from tensorflow.python.keras import backend as K from tensorflow.python.keras.saving import hdf5_format from .configuration_utils import PretrainedConfig from .file_utils import DUMMY_INPUTS, TF2_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url from .generation_tf_utils import TFGenerationMixin -from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model from .utils import logging @@ -216,6 +216,91 @@ class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss): """ +def detect_tf_missing_unexpected_layers(model, resolved_archive_file): + """ + Detect missing and unexpected layers. + + Args: + model (:obj:`tf.keras.models.Model`): + The model to load the weights into. + resolved_archive_file (:obj:`str`): + The location of the H5 file. + + Returns: + Two lists, one for the missing layers, and another one for the unexpected layers. + """ + missing_layers = [] + unexpected_layers = [] + + with h5py.File(resolved_archive_file, "r") as f: + saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names")) + model_layer_names = set(layer.name for layer in model.layers) + missing_layers = list(model_layer_names - saved_layer_names) + unexpected_layers = list(saved_layer_names - model_layer_names) + + for layer in model.layers: + if layer.name in saved_layer_names: + g = f[layer.name] + saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names") + saved_weight_names_set = set( + "/".join(weight_name.split("/")[2:]) for weight_name in saved_weight_names + ) + symbolic_weights = layer.trainable_weights + layer.non_trainable_weights + symbolic_weights_names = set( + "/".join(symbolic_weight.name.split("/")[2:]) for symbolic_weight in symbolic_weights + ) + missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set)) + unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names)) + + return missing_layers, unexpected_layers + + +def load_tf_weights(model, resolved_archive_file): + """ + Load the TF weights from a H5 file. + + Args: + model (:obj:`tf.keras.models.Model`): + The model to load the weights into. + resolved_archive_file (:obj:`str`): + The location of the H5 file. + """ + with h5py.File(resolved_archive_file, "r") as f: + saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names")) + weight_value_tuples = [] + + for layer in model.layers: + if layer.name in saved_layer_names: + g = f[layer.name] + saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names") + symbolic_weights = layer.trainable_weights + layer.non_trainable_weights + saved_weight_names_values = {} + + for weight_name in saved_weight_names: + name = "/".join(weight_name.split("/")[1:]) + saved_weight_names_values[name] = np.asarray(g[weight_name]) + + for symbolic_weight in symbolic_weights: + splited_layers = symbolic_weight.name.split("/")[1:] + symbolic_weight_name = "/".join(splited_layers) + + if symbolic_weight_name in saved_weight_names_values: + saved_weight_value = saved_weight_names_values[symbolic_weight_name] + + if K.int_shape(symbolic_weight) != saved_weight_value.shape: + try: + array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight)) + except AssertionError as e: + e.args += (K.int_shape(symbolic_weight), saved_weight_value.shape) + raise e + else: + array = saved_weight_value + + weight_value_tuples.append((symbolic_weight, array)) + + K.batch_set_value(weight_value_tuples) + + class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): r""" Base class for all TF models. @@ -231,10 +316,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture. - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model. + - **authorized_missing_keys** (:obj:`List[str]`, `optional`) -- A list of re pattern of tensor names to ignore + from the model when loading the model weights (and avoid unnecessary warnings). + - **authorized_unexpected_keys** (:obj:`List[str]`, `optional`) -- A list of re pattern of tensor names to ignore + from the weights when loading the model weights (and avoid unnecessary warnings). """ config_class = None base_model_prefix = "" authorized_missing_keys = None + authorized_unexpected_keys = None @property def dummy_inputs(self) -> Dict[str, tf.Tensor]: @@ -604,6 +694,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): model = cls(config, *model_args, **model_kwargs) if from_pt: + from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model + # Load from a PyTorch checkpoint return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file, allow_missing_keys=True) @@ -613,7 +705,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): # 'by_name' allow us to do transfer learning by skipping/adding layers # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357 try: - model.load_weights(resolved_archive_file, by_name=True) + load_tf_weights(model, resolved_archive_file) except OSError: raise OSError( "Unable to load weights from h5 file. " @@ -622,23 +714,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): model(model.dummy_inputs, training=False) # Make sure restore ops are run - # Check if the models are the same to output loading informations - with h5py.File(resolved_archive_file, "r") as f: - if "layer_names" not in f.attrs and "model_weights" in f: - f = f["model_weights"] - hdf5_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names")) - model_layer_names = set(layer.name for layer in model.layers) - missing_keys = list(model_layer_names - hdf5_layer_names) - unexpected_keys = list(hdf5_layer_names - model_layer_names) - error_msgs = [] + missing_keys, unexpected_keys = detect_tf_missing_unexpected_layers(model, resolved_archive_file) if cls.authorized_missing_keys is not None: for pat in cls.authorized_missing_keys: missing_keys = [k for k in missing_keys if re.search(pat, k) is None] + if cls.authorized_unexpected_keys is not None: + for pat in cls.authorized_unexpected_keys: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + if len(unexpected_keys) > 0: logger.warning( - f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when " + f"Some layers from the model checkpoint at {pretrained_model_name_or_path} were not used when " f"initializing {model.__class__.__name__}: {unexpected_keys}\n" f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task " f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n" @@ -646,25 +734,24 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." ) else: - logger.warning(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + logger.warning(f"All model checkpoint layers were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: logger.warning( - f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " + f"Some layers of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " f"and are newly initialized: {missing_keys}\n" f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." ) else: logger.warning( - f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n" + f"All the layers of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n" f"If your task is similar to the task the model of the checkpoint was trained on, " f"you can already use {model.__class__.__name__} for predictions without further training." ) - if len(error_msgs) > 0: - raise RuntimeError( - "Error(s) in loading weights for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)) - ) + if output_loading_info: - loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs} + loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys} + return model, loading_info return model diff --git a/tests/test_modeling_tf_bert.py b/tests/test_modeling_tf_bert.py index ed25c4c8e5..7fbdb08c87 100644 --- a/tests/test_modeling_tf_bert.py +++ b/tests/test_modeling_tf_bert.py @@ -17,7 +17,7 @@ import unittest from transformers import BertConfig, is_tf_available -from transformers.testing_utils import require_tf, slow +from transformers.testing_utils import require_tf from .test_configuration_common import ConfigTester from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor @@ -317,9 +317,14 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_bert_for_token_classification(*config_and_inputs) - @slow def test_model_from_pretrained(self): - # for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - for model_name in ["bert-base-uncased"]: - model = TFBertModel.from_pretrained(model_name) - self.assertIsNotNone(model) + model = TFBertModel.from_pretrained("jplu/tiny-tf-bert-random") + self.assertIsNotNone(model) + + def test_custom_load_tf_weights(self): + model, output_loading_info = TFBertForTokenClassification.from_pretrained( + "jplu/tiny-tf-bert-random", use_cdn=False, output_loading_info=True + ) + self.assertEqual(sorted(output_loading_info["unexpected_keys"]), ["mlm___cls", "nsp___cls"]) + for layer in output_loading_info["missing_keys"]: + self.assertTrue(layer.split("_")[0] in ["dropout", "classifier"])