From c34a525d2faea2976fbeeabbaaae929d05f8d8a7 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 9 May 2023 19:04:27 +0200 Subject: [PATCH] Proposed fix for TF example now running on safetensors. (#23208) * Proposed fix for TF example now running on safetensors. * Adding more warnings and returning keys. * Trigger CI * Trigger CI --------- Co-authored-by: Sylvain Gugger --- .../tensorflow/test_tensorflow_examples.py | 1 - src/transformers/modeling_tf_pytorch_utils.py | 35 +++++++++++++++++-- src/transformers/modeling_tf_utils.py | 1 + 3 files changed, 34 insertions(+), 3 deletions(-) diff --git a/examples/tensorflow/test_tensorflow_examples.py b/examples/tensorflow/test_tensorflow_examples.py index d5ae4c71b8..956209baad 100644 --- a/examples/tensorflow/test_tensorflow_examples.py +++ b/examples/tensorflow/test_tensorflow_examples.py @@ -297,7 +297,6 @@ class ExamplesTests(TestCasePlus): result = get_results(tmp_dir) self.assertGreaterEqual(result["bleu"], 30) - @skip("Fix me Matt") def test_run_image_classification(self): tmp_dir = self.get_auto_remove_tmp_dir() testargs = f""" diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index 402159cc6f..3b1c030699 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -246,6 +246,7 @@ def load_pytorch_state_dict_in_tf2_model( output_loading_info=False, _prefix=None, tf_to_pt_weight_rename=None, + ignore_mismatched_sizes=False, ): """Load a pytorch state_dict in a TF 2.0 model.""" import tensorflow as tf @@ -297,6 +298,7 @@ def load_pytorch_state_dict_in_tf2_model( weight_value_tuples = [] all_pytorch_weights = set(pt_state_dict.keys()) missing_keys = [] + mismatched_keys = [] for symbolic_weight in symbolic_weights: sw_name = symbolic_weight.name name, transpose = convert_tf_weight_name_to_pt_weight_name( @@ -319,7 +321,18 @@ def load_pytorch_state_dict_in_tf2_model( continue raise AttributeError(f"{name} not found in PyTorch model") - array = apply_transpose(transpose, pt_state_dict[name], symbolic_weight.shape) + try: + array = apply_transpose(transpose, pt_state_dict[name], symbolic_weight.shape) + except tf.errors.InvalidArgumentError as e: + if not ignore_mismatched_sizes: + error_msg = str(e) + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." + ) + raise tf.errors.InvalidArgumentError(error_msg) + else: + mismatched_keys.append((name, pt_state_dict[name].shape, symbolic_weight.shape)) + continue tf_loaded_numel += tensor_size(array) @@ -367,8 +380,26 @@ def load_pytorch_state_dict_in_tf2_model( f"you can already use {tf_model.__class__.__name__} for predictions without further training." ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {tf_model.__class__.__name__} were not initialized from the model checkpoint" + f" are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" + " to use it for predictions and inference." + ) + if output_loading_info: - loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys} + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + } return tf_model, loading_info return tf_model diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index f48651a6e9..35c526379c 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -2820,6 +2820,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu allow_missing_keys=True, output_loading_info=output_loading_info, _prefix=load_weight_prefix, + ignore_mismatched_sizes=ignore_mismatched_sizes, ) # 'by_name' allow us to do transfer learning by skipping/adding layers