Sharding fails in TF when the absolute scope was modified and a layer name contains "." (#19124)
* simplify loop
* fix layer map split
* update
* update for special variables
* add rag test
* fixup
* revert change: for next PR
This commit is contained in:
@@ -77,9 +77,11 @@ if is_tf_available():
|
||||
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
BertConfig,
|
||||
RagRetriever,
|
||||
TFAutoModel,
|
||||
TFAutoModelForSequenceClassification,
|
||||
TFBertModel,
|
||||
TFRagModel,
|
||||
TFSharedEmbeddings,
|
||||
)
|
||||
from transformers.generation_tf_utils import (
|
||||
@@ -2167,6 +2169,18 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
},
|
||||
)
|
||||
|
||||
@slow
def test_special_layer_name_shardind(self):
    """Check that a TF RAG model survives a sharded save/load round trip.

    The pretrained ``facebook/rag-token-nq`` model is saved with several
    small ``max_shard_size`` values, reloaded from disk, and every weight
    of the reloaded model is compared against the original.
    """
    retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
    model = TFRagModel.from_pretrained("facebook/rag-token-nq", retriever=retriever)

    for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
        # Use a fresh directory per shard size so that shard files written
        # for a previous size can never be picked up by a later reload.
        with self.subTest(max_shard_size=max_size), tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, max_shard_size=max_size)
            ref_model = TFRagModel.from_pretrained(tmp_dir, retriever=retriever)

            # zip() stops at the shorter sequence, so a reloaded model that
            # is missing weights would otherwise pass unnoticed.
            self.assertEqual(len(model.weights), len(ref_model.weights))

            for p1, p2 in zip(model.weights, ref_model.weights):
                self.assertTrue(
                    np.allclose(p1.numpy(), p2.numpy()),
                    msg=f"Weight {p1.name} differs after reloading with max_shard_size={max_size}",
                )
|
||||
|
||||
def test_checkpoint_sharding_local(self):
|
||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user