From 84c9bf74212728d337571e9b830ea4eabae34a0c Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Mon, 5 Dec 2022 16:56:17 +0100 Subject: [PATCH] cross platform from_pretrained (#20538) * add support for `from_pt` * add tf_flax utility file * Update src/transformers/modeling_tf_flax_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * remove flax related modifications * add test * remove FLAX related commits * fixup * remove safetensor todos * revert deletion Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/modeling_tf_utils.py | 22 ++++++++++++++++------ tests/test_modeling_tf_common.py | 8 ++++++++ 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index fa6df4e40a..a642e883b0 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -47,6 +47,7 @@ from .utils import ( SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, + TF_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME, ModelOutput, @@ -2392,7 +2393,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu save directory. - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named *config.json* is found in the directory. - from_pt: (`bool`, *optional*, defaults to `False`): + from_pt (`bool`, *optional*, defaults to `False`): Load the model weights from a PyTorch state_dict save file (see docstring of `pretrained_model_name_or_path` argument). ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): @@ -2531,7 +2532,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu if pretrained_model_name_or_path is not None: pretrained_model_name_or_path = str(pretrained_model_name_or_path) is_local = os.path.isdir(pretrained_model_name_or_path) - if os.path.isdir(pretrained_model_name_or_path): + if is_local: if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): # Load from a PyTorch checkpoint in priority if from_pt archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) @@ -2559,7 +2560,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME) is_sharded = True # At this stage we don't have a weight file so we will raise an error. - elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME): + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)) or os.path.isfile( + os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME) + ): raise EnvironmentError( f"Error no file named {TF2_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those " @@ -2630,6 +2633,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ) if resolved_archive_file is not None: is_sharded = True + if resolved_archive_file is None and filename == WEIGHTS_NAME: + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs + ) + if resolved_archive_file is not None: + is_sharded = True if resolved_archive_file is None: # Otherwise, maybe there is a PyTorch or Flax model file. We try those to give a helpful error # message. @@ -2646,8 +2656,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ) else: raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named" - f" {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}." + f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}," + f" {TF2_WEIGHTS_NAME} or {TF_WEIGHTS_NAME}" ) except EnvironmentError: @@ -2661,7 +2671,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" " from 'https://huggingface.co/models', make sure you don't have a local directory with the" f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" - f" directory containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}." + f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME} or {TF_WEIGHTS_NAME}" ) if is_local: logger.info(f"loading weights file {archive_file}") diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 679776dd8c..6908c73c99 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -2127,6 +2127,14 @@ class UtilsFunctionsTest(unittest.TestCase): for p1, p2 in zip(model.weights, ref_model.weights): assert np.allclose(p1.numpy(), p2.numpy()) + @is_pt_tf_cross_test + def test_checkpoint_sharding_hub_from_pt(self): + model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True) + # the model above is the same as the model below, just a sharded pytorch version. + ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert") + for p1, p2 in zip(model.weights, ref_model.weights): + assert np.allclose(p1.numpy(), p2.numpy()) + def test_shard_checkpoint(self): # This is the model we will use, total size 340,000 bytes. model = tf.keras.Sequential(