From 2bd2de62c941f23dc744b8bbffafde609e68d1d5 Mon Sep 17 00:00:00 2001
From: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Date: Fri, 14 Oct 2022 18:34:33 +0200
Subject: [PATCH] Sharding fails in TF when absolute scope was modified if `.` in layer name (#19124)

* simplify loop

* fix layer map split

* update

* update for special variables

* add rag test

* fixup

* revert change : for next PR
---
Reviewer notes (text between `---` and the first `diff --git` is not part of
the commit):
- What the patch does: when mapping weight names for sharded checkpoints, a
  name that contains "model." or has no "/" is now used unchanged; any other
  name still loses its first "/"-separated component. The same rule is applied
  when loading (load_tf_sharded_weights) and when writing shards
  (TFPreTrainedModel), and "layer_names" is now written in the same sorted
  order as the datasets.
- Edited after export: the leftover debug `print(layer_name)` in the
  shard-writing loop was dropped and a one-line comment describing the naming
  rule was added in its place, so hunk sizes and the diffstat are unchanged.
  The post-image blob id on the `index` line of modeling_tf_utils.py refers to
  the unedited commit.
- Line breaks and leading whitespace were rebuilt from the diff syntax and the
  hunk line counts. TODO confirm context-line indentation against the tree, or
  apply with `git apply --ignore-whitespace`.

 src/transformers/modeling_tf_utils.py | 27 ++++++++++++++++++---------
 tests/test_modeling_tf_common.py      | 14 ++++++++++++++
 2 files changed, 32 insertions(+), 9 deletions(-)

diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py
index 85fa1f3b73..e86de7c914 100644
--- a/src/transformers/modeling_tf_utils.py
+++ b/src/transformers/modeling_tf_utils.py
@@ -707,8 +707,15 @@ def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, s
 
     # Since TF adds the name of the class to its weights, and uses the index and not the name of the layer to load
     # the weight, we have to get rid of the first prefix of the name of the layer.
-    model_keys = set("/".join(k.name.split("/")[1:]) for k in model.weights)
-    model_layer_map = {"/".join(k.name.split("/")[1:]): i for i, k in enumerate(model.weights)}
+    model_keys = set()
+    model_layer_map = dict()
+    for i, k in enumerate(model.weights):
+        if "model." in k.name or len(k.name.split("/")) == 1:
+            layer_name = k.name
+        else:
+            layer_name = "/".join(k.name.split("/")[1:])
+        model_keys.add(layer_name)
+        model_layer_map[layer_name] = i
 
     for shard_file in shard_files:
         state_dict = tf.io.read_file(shard_file)
@@ -2211,17 +2218,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             )
             for shard_file, shard in shards.items():
                 with h5py.File(os.path.join(save_directory, shard_file), mode="w") as shard_file:
-                    save_attributes_to_hdf5_group(
-                        shard_file,
-                        "layer_names",
-                        ["/".join(layer.name.split("/")[1:]).encode("utf8") for layer in shard],
-                    )
-
+                    layers = []
                     for layer in sorted(shard, key=lambda x: x.name):
+                        # Keep names with "model." or no "/" as-is; otherwise drop the first "/" component.
+                        if "model." in layer.name or len(layer.name.split("/")) == 1:
+                            layer_name = layer.name
+                        else:
+                            layer_name = "/".join(layer.name.split("/")[1:])
                         param_dset = shard_file.create_dataset(
-                            "/".join(layer.name.split("/")[1:]), layer.numpy().shape, dtype=layer.numpy().dtype
+                            layer_name, layer.numpy().shape, dtype=layer.numpy().dtype
                         )
                         param_dset[:] = layer.numpy()
+                        layers.append(layer_name.encode("utf8"))
+                    save_attributes_to_hdf5_group(shard_file, "layer_names", layers)
 
         if push_to_hub:
             self._upload_modified_files(
diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py
index e054514d74..a82b5d51b2 100644
--- a/tests/test_modeling_tf_common.py
+++ b/tests/test_modeling_tf_common.py
@@ -77,9 +77,11 @@ if is_tf_available():
         TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
         TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
         BertConfig,
+        RagRetriever,
         TFAutoModel,
         TFAutoModelForSequenceClassification,
         TFBertModel,
+        TFRagModel,
         TFSharedEmbeddings,
     )
     from transformers.generation_tf_utils import (
@@ -2167,6 +2169,18 @@ class UtilsFunctionsTest(unittest.TestCase):
                 },
             )
 
+    @slow
+    def test_special_layer_name_shardind(self):
+        retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
+        model = TFRagModel.from_pretrained("facebook/rag-token-nq", retriever=retriever)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
+                model.save_pretrained(tmp_dir, max_shard_size=max_size)
+                ref_model = TFRagModel.from_pretrained(tmp_dir, retriever=retriever)
+                for p1, p2 in zip(model.weights, ref_model.weights):
+                    assert np.allclose(p1.numpy(), p2.numpy())
+
     def test_checkpoint_sharding_local(self):
         model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
 