Sharding fails in TF when absolute scope was modified if . in layer name (#19124)

* simplify loop

* fix layer map split

* update

* update for special variables

* add rag test

* fixup

* revert change : for next PR
This commit is contained in:
Arthur
2022-10-14 18:34:33 +02:00
committed by GitHub
parent 614f7d28a8
commit 2bd2de62c9
2 changed files with 32 additions and 9 deletions

View File

@@ -77,9 +77,11 @@ if is_tf_available():
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
BertConfig,
RagRetriever,
TFAutoModel,
TFAutoModelForSequenceClassification,
TFBertModel,
TFRagModel,
TFSharedEmbeddings,
)
from transformers.generation_tf_utils import (
@@ -2167,6 +2169,18 @@ class UtilsFunctionsTest(unittest.TestCase):
},
)
@slow
def test_special_layer_name_shardind(self):
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
model = TFRagModel.from_pretrained("facebook/rag-token-nq", retriever=retriever)
with tempfile.TemporaryDirectory() as tmp_dir:
for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
model.save_pretrained(tmp_dir, max_shard_size=max_size)
ref_model = TFRagModel.from_pretrained(tmp_dir, retriever=retriever)
for p1, p2 in zip(model.weights, ref_model.weights):
assert np.allclose(p1.numpy(), p2.numpy())
def test_checkpoint_sharding_local(self):
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")