From acfb714bdf6a7fdc52c2f171783c878bea6be4ea Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 28 Feb 2023 18:41:34 +0000 Subject: [PATCH] Improve TF weight loading, especially PT crossloading (#21792) * First commit for the improved PT-TF weight loading * Remove workarounds from TFEncoderDecoder tests * Allow a custom weight renaming function in from_pretrained and use that to clean up EncoderDecoder * make fixup * First attempt at visionencoderdecoder * Disable tensorfloat32 in tests to get consistent outputs * Quick fix to tf_vision_encoder_decoder tests * make fixup * Update Blenderbot tests * Remove unused arg in modeling_tf_opt * load_tf_sharded_weights had strict=True! This meant transfer learning was impossible, so I'm setting it to False. * Support prefixes when loading sharded TF checkpoints * make fixup * Add test to load sharded models with a weight prefix * Fix sharded weight loading test * Add a test for transfer from a sharded checkpoint * make fixup * Add test to check that crossloading from PT with a prefix works * Refactor from_pretrained in the encoderdecoder classes * Refactor from_pretrained in the encoderdecoder classes * missmatched -> mismatched * Explicitly check for None * No comments showing my very impressive and attractive knowledge of Py3.9+ * Disable TF32 across all TF tests --- src/transformers/modeling_tf_pytorch_utils.py | 58 +++++++++++++-- src/transformers/modeling_tf_utils.py | 58 ++++++++++----- .../modeling_tf_encoder_decoder.py | 71 ++++-------------- .../models/opt/modeling_tf_opt.py | 2 +- .../modeling_tf_vision_encoder_decoder.py | 73 ++++--------------- ...test_modeling_tf_vision_encoder_decoder.py | 8 +- tests/test_modeling_tf_common.py | 25 +++++++ 7 files changed, 147 insertions(+), 148 deletions(-) diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index 5465da7427..402159cc6f 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -39,7 +39,9 @@ class TransposeType(ExplicitEnum): CONV2D = "conv2d" -def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="", tf_weight_shape=None): +def convert_tf_weight_name_to_pt_weight_name( + tf_name, start_prefix_to_remove="", tf_weight_shape=None, name_scope=None +): """ Convert a TF 2.0 model variable name in a pytorch model weight name. @@ -54,6 +56,14 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="", - transpose: `TransposeType` member indicating whether and how TF2.0 and PyTorch weights matrices should be transposed with regards to each other """ + if name_scope is not None: + if not tf_name.startswith(name_scope): + raise ValueError( + f"Weight name {tf_name} does not start with name_scope {name_scope}. This is an internal error " + "in Transformers, so (unless you were doing something really evil) please open an issue to report it!" + ) + tf_name = tf_name[len(name_scope) :] + tf_name = tf_name.lstrip("/") tf_name = tf_name.replace(":0", "") # device ids tf_name = re.sub( r"/[^/]*___([^/]*)/", r"/\1/", tf_name @@ -144,7 +154,13 @@ def apply_transpose(transpose: TransposeType, weight, match_shape=None, pt_to_tf def load_pytorch_checkpoint_in_tf2_model( - tf_model, pytorch_checkpoint_path, tf_inputs=None, allow_missing_keys=False, output_loading_info=False + tf_model, + pytorch_checkpoint_path, + tf_inputs=None, + allow_missing_keys=False, + output_loading_info=False, + _prefix=None, + tf_to_pt_weight_rename=None, ): """Load pytorch checkpoints in a TF 2.0 model""" try: @@ -176,6 +192,8 @@ def load_pytorch_checkpoint_in_tf2_model( tf_inputs=tf_inputs, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info, + _prefix=_prefix, + tf_to_pt_weight_rename=tf_to_pt_weight_rename, ) @@ -189,7 +207,13 @@ def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None, allow_mi def load_pytorch_weights_in_tf2_model( - tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False, output_loading_info=False + tf_model, + pt_state_dict, + tf_inputs=None, + allow_missing_keys=False, + output_loading_info=False, + _prefix=None, + tf_to_pt_weight_rename=None, ): """Load pytorch state_dict in a TF 2.0 model.""" try: @@ -209,11 +233,19 @@ def load_pytorch_weights_in_tf2_model( tf_inputs=tf_inputs, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info, + _prefix=_prefix, + tf_to_pt_weight_rename=tf_to_pt_weight_rename, ) def load_pytorch_state_dict_in_tf2_model( - tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False, output_loading_info=False + tf_model, + pt_state_dict, + tf_inputs=None, + allow_missing_keys=False, + output_loading_info=False, + _prefix=None, + tf_to_pt_weight_rename=None, ): """Load a pytorch state_dict in a TF 2.0 model.""" import tensorflow as tf @@ -227,8 +259,11 @@ def load_pytorch_state_dict_in_tf2_model( if tf_inputs is None: tf_inputs = tf_model.dummy_inputs + if _prefix is None: + _prefix = "" if tf_inputs is not None: - tf_model(tf_inputs, training=False) # Make sure model is built + with tf.name_scope(_prefix): + tf_model(tf_inputs, training=False) # Make sure model is built # Adapt state dict - TODO remove this and update the AWS weights files instead # Convert old format to new format if needed from a PyTorch state_dict old_keys = [] @@ -249,8 +284,10 @@ def load_pytorch_state_dict_in_tf2_model( for old_key, new_key in zip(old_keys, new_keys): pt_state_dict[new_key] = pt_state_dict.pop(old_key) - # Make sure we are able to load PyTorch base models as well as derived models (with heads) - # TF models always have a prefix, some of PyTorch models (base ones) don't + # Matt: All TF models store the actual model stem in a MainLayer class, including the base model. + # In PT, the derived models (with heads) use the base model class as the stem instead, and the base model + # just contains the stem itself, and there is no MainLayer class. This means that TF base classes have one + # extra layer in their weight names, corresponding to the MainLayer class. This code block compensates for that. start_prefix_to_remove = "" if not any(s.startswith(tf_model.base_model_prefix) for s in pt_state_dict.keys()): start_prefix_to_remove = tf_model.base_model_prefix + "." @@ -263,8 +300,13 @@ def load_pytorch_state_dict_in_tf2_model( for symbolic_weight in symbolic_weights: sw_name = symbolic_weight.name name, transpose = convert_tf_weight_name_to_pt_weight_name( - sw_name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=symbolic_weight.shape + sw_name, + start_prefix_to_remove=start_prefix_to_remove, + tf_weight_shape=symbolic_weight.shape, + name_scope=_prefix, ) + if tf_to_pt_weight_rename is not None: + name = tf_to_pt_weight_rename(name) # Find associated numpy array in pytorch model state dict if name not in pt_state_dict: diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index c469c13ff0..be0c6b313b 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -707,7 +707,7 @@ def tf_shard_checkpoint(weights, max_shard_size="10GB"): return shards, index -def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, strict=True): +def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, strict=False, _prefix=None): """ This is the same as `load_tf_weights` but for a sharded checkpoint. Detect missing and unexpected layers and load the TF weights from the shard file accordingly to their names and shapes. @@ -729,32 +729,35 @@ def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, s """ # Load the index - missing_keys = [] unexpected_keys = set() saved_keys = set() - missmatched_keys = set() + mismatched_keys = set() # Since TF adds the name of the class to its weights, and uses the index and not the name of the layer to load # the weight, we have to get rid of the first prefix of the name of the layer. model_keys = set() model_layer_map = {} for i, k in enumerate(model.weights): - if "model." in k.name or len(k.name.split("/")) == 1: - layer_name = k.name - else: - layer_name = "/".join(k.name.split("/")[1:]) + layer_name = k.name + if _prefix is not None and layer_name.startswith(_prefix): + layer_name = layer_name[len(_prefix) :] + layer_name = layer_name.lstrip("/") + if not ("model." in layer_name or len(layer_name.split("/")) == 1): + layer_name = "/".join(layer_name.split("/")[1:]) model_keys.add(layer_name) model_layer_map[layer_name] = i for shard_file in shard_files: - state_dict = tf.io.read_file(shard_file) - saved_weight_names_set, unexpected_keys_set, missmatched_keys_set = load_tf_shard( - model, model_layer_map, shard_file, ignore_mismatched_sizes=ignore_mismatched_sizes + saved_weight_names_set, unexpected_keys_set, mismatched_keys_set = load_tf_shard( + model, + model_layer_map, + shard_file, + ignore_mismatched_sizes=ignore_mismatched_sizes, + _prefix=_prefix, ) saved_keys.update(saved_weight_names_set) unexpected_keys.update(unexpected_keys_set) - missmatched_keys.update(missmatched_keys_set) - del state_dict + mismatched_keys.update(mismatched_keys_set) gc.collect() missing_keys = model_keys - saved_keys @@ -768,10 +771,10 @@ def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, s error_message += f"\nMissing key(s): {str_unexpected_keys}." raise RuntimeError(error_message) - return missing_keys, unexpected_keys, missmatched_keys + return missing_keys, unexpected_keys, mismatched_keys -def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatched_sizes=False): +def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None): """ Loads a shard from a sharded checkpoint file. Handles the missing keys and unexpected keys. @@ -783,11 +786,11 @@ def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatch Returns: `tf.keras.models.Model`: Three lists, one for the layers that were found and succesfully restored (from the - shard file), one for the missmatched layers, and another one for the unexpected layers. + shard file), one for the mismatched layers, and another one for the unexpected layers. """ saved_weight_names_set = set() saved_weights = {} - missmatched_keys = set() + mismatched_keys = set() unexpected_keys = set() # Read the H5 file try: @@ -822,7 +825,7 @@ def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatch array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight)) except ValueError as e: if ignore_mismatched_sizes: - missmatched_keys.add( + mismatched_keys.add( (layer_name, saved_weight_value.shape, K.int_shape(symbolic_weight)) ) continue @@ -836,7 +839,7 @@ def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatch K.batch_set_value(weight_value_tuples) - return saved_weight_names_set, unexpected_keys, missmatched_keys + return saved_weight_names_set, unexpected_keys, mismatched_keys except Exception as e: try: @@ -2458,6 +2461,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu subfolder (`str`, *optional*, defaults to `""`): In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here. + tf_to_pt_weight_rename (`Callable`, *optional*): + A function that is called to transform the names of weights during the PyTorch to TensorFlow + crossloading process. This is not necessary for most models, but is useful to allow composite models to + be crossloaded correctly. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or @@ -2506,6 +2513,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu from_auto_class = kwargs.pop("_from_auto", False) subfolder = kwargs.pop("subfolder", "") commit_hash = kwargs.pop("_commit_hash", None) + tf_to_pt_weight_rename = kwargs.pop("tf_to_pt_weight_rename", None) if trust_remote_code is True: logger.warning( @@ -2745,7 +2753,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu # Load from a PyTorch checkpoint return load_pytorch_checkpoint_in_tf2_model( - model, resolved_archive_file, allow_missing_keys=True, output_loading_info=output_loading_info + model, + resolved_archive_file, + allow_missing_keys=True, + output_loading_info=output_loading_info, + _prefix=load_weight_prefix, + tf_to_pt_weight_rename=tf_to_pt_weight_rename, ) # we might need to extend the variable scope for composite models @@ -2761,7 +2774,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu state_dict = safe_load_file(resolved_archive_file) # Load from a PyTorch checkpoint return load_pytorch_state_dict_in_tf2_model( - model, state_dict, allow_missing_keys=True, output_loading_info=output_loading_info + model, + state_dict, + allow_missing_keys=True, + output_loading_info=output_loading_info, + _prefix=load_weight_prefix, ) # 'by_name' allow us to do transfer learning by skipping/adding layers @@ -2775,6 +2792,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu model, resolved_archive_file, ignore_mismatched_sizes=ignore_mismatched_sizes, + _prefix=load_weight_prefix, ) else: missing_keys, unexpected_keys, mismatched_keys = load_tf_weights( diff --git a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py index 5c301ed4d0..be6d03a131 100644 --- a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py @@ -15,9 +15,7 @@ """ Classes to support TF Encoder-Decoder architectures""" -import gc -import os -import tempfile +import re import warnings from typing import Optional, Tuple, Union @@ -306,46 +304,23 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss): >>> model = TFEncoderDecoderModel.from_pretrained("ydshieh/bert2bert-cnn_dailymail-fp16") ```""" + # Matt: The TF and PT weights don't align because our TF base classes have an extra layer compared to PT models + # (the main model stem is in the MainLayer class). If we remove that layer, then weight names sync up as normal. + # However, the name of that extra layer is the name of the MainLayer in the base model. We make the assumption + # here that the config model_type is the same as the name of the MainLayer. I don't know of anywhere that's + # not the case, and I wasn't sure how else to go from the config to the correct MainLayer name! - from_pt = kwargs.pop("from_pt", False) - if from_pt: - import torch + if kwargs.get("from_pt", False): + config = AutoConfig.from_pretrained(pretrained_model_name_or_path) + encoder_model_type = config.encoder.model_type - from transformers import EncoderDecoderModel - - # a workaround to load from pytorch checkpoint - _model = EncoderDecoderModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - config = _model.config - - with tempfile.TemporaryDirectory() as tmpdirname: - encoder_dir = os.path.join(tmpdirname, "encoder") - decoder_dir = os.path.join(tmpdirname, "decoder") - _model.encoder.save_pretrained(encoder_dir) - _model.decoder.save_pretrained(decoder_dir) - - if hasattr(_model, "enc_to_dec_proj"): - enc_to_dec_proj_kernel = tf.transpose( - tf.constant(_model.enc_to_dec_proj.weight.detach().to("cpu").numpy()), perm=(1, 0) - ) - enc_to_dec_proj_bias = tf.constant(_model.enc_to_dec_proj.bias.detach().to("cpu").numpy()) - - del _model - gc.collect() - torch.cuda.empty_cache() - - model = TFEncoderDecoderModel.from_encoder_decoder_pretrained( - encoder_dir, decoder_dir, encoder_from_pt=True, decoder_from_pt=True - ) - # This is only for copying some specific attributes of this particular model. - model.config = config - - if hasattr(model, "enc_to_dec_proj"): - model(model.dummy_inputs) - model.enc_to_dec_proj.kernel.assign(enc_to_dec_proj_kernel) - model.enc_to_dec_proj.bias.assign(enc_to_dec_proj_bias) - - return model + def tf_to_pt_weight_rename(tf_weight): + if "encoder" in tf_weight and "decoder" not in tf_weight: + return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight) + else: + return tf_weight + kwargs["tf_to_pt_weight_rename"] = tf_to_pt_weight_rename return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) @classmethod @@ -451,14 +426,6 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss): kwargs_encoder["load_weight_prefix"] = cls.load_weight_prefix encoder = TFAutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder) - # This is necessary to make `from_pretrained` following `save_pretrained` work correctly - if kwargs_encoder.get("from_pt", None): - del kwargs_encoder["from_pt"] - with tempfile.TemporaryDirectory() as tmp_dirname: - encoder.save_pretrained(tmp_dirname) - del encoder - encoder = TFAutoModel.from_pretrained(tmp_dirname, *model_args, **kwargs_encoder) - decoder = kwargs_decoder.pop("model", None) if decoder is None: if decoder_pretrained_model_name_or_path is None: @@ -493,14 +460,6 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss): kwargs_decoder["load_weight_prefix"] = cls.load_weight_prefix decoder = TFAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) - # This is necessary to make `from_pretrained` following `save_pretrained` work correctly - if kwargs_decoder.get("from_pt", None): - del kwargs_decoder["from_pt"] - with tempfile.TemporaryDirectory() as tmp_dirname: - decoder.save_pretrained(tmp_dirname) - del decoder - decoder = TFAutoModelForCausalLM.from_pretrained(tmp_dirname, **kwargs_decoder) - # Make sure these 2 `tf.keras.Model` have fixed names so `from_pretrained` could load model weights correctly. if encoder.name != "encoder": raise ValueError("encoder model must be created with the name `encoder`.") diff --git a/src/transformers/models/opt/modeling_tf_opt.py b/src/transformers/models/opt/modeling_tf_opt.py index 2fcbd444a1..cd34130228 100644 --- a/src/transformers/models/opt/modeling_tf_opt.py +++ b/src/transformers/models/opt/modeling_tf_opt.py @@ -486,7 +486,7 @@ OPT_INPUTS_DOCSTRING = r""" class TFOPTDecoder(tf.keras.layers.Layer): config_class = OPTConfig - def __init__(self, config: OPTConfig, load_weight_prefix=None, **kwargs): + def __init__(self, config: OPTConfig, **kwargs): super().__init__(**kwargs) self.config = config self.padding_idx = config.pad_token_id diff --git a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py index bec9a4f1ee..5af7c195ff 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py @@ -15,9 +15,7 @@ """ Classes to support TF Vision-Encoder-Text-Decoder architectures""" -import gc -import os -import tempfile +import re import warnings from typing import Optional, Tuple, Union @@ -320,46 +318,23 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos >>> assert preds == ["a cat laying on top of a couch next to another cat"] ```""" + # Matt: The TF and PT weights don't align because our TF base classes have an extra layer compared to PT models + # (the main model stem is in the MainLayer class). If we remove that layer, then weight names sync up as normal. + # However, the name of that extra layer is the name of the MainLayer in the base model. We make the assumption + # here that the config model_type is the same as the name of the MainLayer. I don't know of anywhere that's + # not the case, and I wasn't sure how else to go from the config to the correct MainLayer name! - from_pt = kwargs.pop("from_pt", False) - if from_pt: - import torch + if kwargs.get("from_pt", False): + config = AutoConfig.from_pretrained(pretrained_model_name_or_path) + encoder_model_type = config.encoder.model_type - from transformers import VisionEncoderDecoderModel - - # a workaround to load from pytorch checkpoint - _model = VisionEncoderDecoderModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - config = _model.config - - with tempfile.TemporaryDirectory() as tmpdirname: - encoder_dir = os.path.join(tmpdirname, "encoder") - decoder_dir = os.path.join(tmpdirname, "decoder") - _model.encoder.save_pretrained(encoder_dir) - _model.decoder.save_pretrained(decoder_dir) - - if hasattr(_model, "enc_to_dec_proj"): - enc_to_dec_proj_kernel = tf.transpose( - tf.constant(_model.enc_to_dec_proj.weight.detach().to("cpu").numpy()), perm=(1, 0) - ) - enc_to_dec_proj_bias = tf.constant(_model.enc_to_dec_proj.bias.detach().to("cpu").numpy()) - - del _model - gc.collect() - torch.cuda.empty_cache() - - model = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained( - encoder_dir, decoder_dir, encoder_from_pt=True, decoder_from_pt=True - ) - # This is only for copying some specific attributes of this particular model. - model.config = config - - if hasattr(model, "enc_to_dec_proj"): - model(model.dummy_inputs) - model.enc_to_dec_proj.kernel.assign(enc_to_dec_proj_kernel) - model.enc_to_dec_proj.bias.assign(enc_to_dec_proj_bias) - - return model + def tf_to_pt_weight_rename(tf_weight): + if "encoder" in tf_weight and "decoder" not in tf_weight: + return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight) + else: + return tf_weight + kwargs["tf_to_pt_weight_rename"] = tf_to_pt_weight_rename return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) @classmethod @@ -466,15 +441,6 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos kwargs_encoder["load_weight_prefix"] = cls.load_weight_prefix encoder = TFAutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder) - # Necessary to make `save_pretrained -> from_pretrained` work correctly for the converted PT -> TF model. - # See https://github.com/huggingface/transformers/pull/14016#issuecomment-944046313 - if kwargs_encoder.get("from_pt", None): - del kwargs_encoder["from_pt"] - with tempfile.TemporaryDirectory() as tmp_dirname: - encoder.save_pretrained(tmp_dirname) - del encoder - encoder = TFAutoModel.from_pretrained(tmp_dirname, *model_args, **kwargs_encoder) - decoder = kwargs_decoder.pop("model", None) if decoder is None: if decoder_pretrained_model_name_or_path is None: @@ -509,15 +475,6 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos kwargs_decoder["load_weight_prefix"] = cls.load_weight_prefix decoder = TFAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) - # Necessary to make `save_pretrained -> from_pretrained` work correctly for the converted PT -> TF model. - # See https://github.com/huggingface/transformers/pull/14016#issuecomment-944046313 - if kwargs_decoder.get("from_pt", None): - del kwargs_decoder["from_pt"] - with tempfile.TemporaryDirectory() as tmp_dirname: - decoder.save_pretrained(tmp_dirname) - del decoder - decoder = TFAutoModelForCausalLM.from_pretrained(tmp_dirname, **kwargs_decoder) - # Make sure these 2 `tf.keras.Model` have fixed names so `from_pretrained` could load model weights correctly. if encoder.name != "encoder": raise ValueError("encoder model must be created with the name `encoder`.") diff --git a/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py index 498f22b17e..1e594f5de5 100644 --- a/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py +++ b/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py @@ -925,16 +925,14 @@ class TFViT2GPT2ModelIntegrationTest(unittest.TestCase): self.assertLessEqual(max_diff, 1e-4) def generate_step(pixel_values): - outputs = model.generate( - pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True, output_scores=True - ) + outputs = model.generate(pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True) output_ids = outputs.sequences preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True) preds = [pred.strip() for pred in preds] - return preds, outputs.scores.numpy() + return preds - preds, scores = generate_step(pixel_values) + preds = generate_step(pixel_values) # should produce # ["a cat laying on top of a couch next to another cat"] diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index afd74411be..9cfd314dfc 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -90,6 +90,7 @@ if is_tf_available(): TFAutoModel, TFAutoModelForSequenceClassification, TFBertForMaskedLM, + TFBertForSequenceClassification, TFBertModel, TFRagModel, TFSharedEmbeddings, @@ -107,6 +108,8 @@ if is_tf_available(): from transformers.modeling_tf_utils import tf_shard_checkpoint, unpack_inputs from transformers.tf_utils import stable_softmax + tf.config.experimental.enable_tensor_float_32_execution(False) + if _tf_gpu_memory_limit is not None: gpus = tf.config.list_physical_devices("GPU") for gpu in gpus: @@ -2140,6 +2143,18 @@ class UtilsFunctionsTest(unittest.TestCase): for p1, p2 in zip(model.weights, ref_model.weights): assert np.allclose(p1.numpy(), p2.numpy()) + def test_sharded_checkpoint_with_prefix(self): + model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", load_weight_prefix="a/b") + sharded_model = TFBertModel.from_pretrained("ArthurZ/tiny-random-bert-sharded", load_weight_prefix="a/b") + for p1, p2 in zip(model.weights, sharded_model.weights): + self.assertTrue(np.allclose(p1.numpy(), p2.numpy())) + self.assertTrue(p1.name.startswith("a/b/")) + self.assertTrue(p2.name.startswith("a/b/")) + + def test_sharded_checkpoint_transfer(self): + # If this doesn't throw an error then the test passes + TFBertForSequenceClassification.from_pretrained("ArthurZ/tiny-random-bert-sharded") + @is_pt_tf_cross_test def test_checkpoint_sharding_local_from_pt(self): with tempfile.TemporaryDirectory() as tmp_dir: @@ -2150,6 +2165,16 @@ class UtilsFunctionsTest(unittest.TestCase): for p1, p2 in zip(model.weights, ref_model.weights): assert np.allclose(p1.numpy(), p2.numpy()) + @is_pt_tf_cross_test + def test_checkpoint_loading_with_prefix_from_pt(self): + model = TFBertModel.from_pretrained( + "hf-internal-testing/tiny-random-bert", from_pt=True, load_weight_prefix="a/b" + ) + ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True) + for p1, p2 in zip(model.weights, ref_model.weights): + self.assertTrue(np.allclose(p1.numpy(), p2.numpy())) + self.assertTrue(p1.name.startswith("a/b/")) + @is_pt_tf_cross_test def test_checkpoint_sharding_hub_from_pt(self): model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)