Sharding fails in TF when absolute scope was modified if . in layer name (#19124)
* simplify loop * fix layer map split * update * update for special variables * add rag test * fixup * revert change : for next PR
This commit is contained in:
@@ -707,8 +707,15 @@ def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, s
|
|||||||
|
|
||||||
# Since TF adds the name of the class to its weights, and uses the index and not the name of the layer to load
|
# Since TF adds the name of the class to its weights, and uses the index and not the name of the layer to load
|
||||||
# the weight, we have to get rid of the first prefix of the name of the layer.
|
# the weight, we have to get rid of the first prefix of the name of the layer.
|
||||||
model_keys = set("/".join(k.name.split("/")[1:]) for k in model.weights)
|
model_keys = set()
|
||||||
model_layer_map = {"/".join(k.name.split("/")[1:]): i for i, k in enumerate(model.weights)}
|
model_layer_map = dict()
|
||||||
|
for i, k in enumerate(model.weights):
|
||||||
|
if "model." in k.name or len(k.name.split("/")) == 1:
|
||||||
|
layer_name = k.name
|
||||||
|
else:
|
||||||
|
layer_name = "/".join(k.name.split("/")[1:])
|
||||||
|
model_keys.add(layer_name)
|
||||||
|
model_layer_map[layer_name] = i
|
||||||
|
|
||||||
for shard_file in shard_files:
|
for shard_file in shard_files:
|
||||||
state_dict = tf.io.read_file(shard_file)
|
state_dict = tf.io.read_file(shard_file)
|
||||||
@@ -2211,17 +2218,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
)
|
)
|
||||||
for shard_file, shard in shards.items():
|
for shard_file, shard in shards.items():
|
||||||
with h5py.File(os.path.join(save_directory, shard_file), mode="w") as shard_file:
|
with h5py.File(os.path.join(save_directory, shard_file), mode="w") as shard_file:
|
||||||
save_attributes_to_hdf5_group(
|
layers = []
|
||||||
shard_file,
|
|
||||||
"layer_names",
|
|
||||||
["/".join(layer.name.split("/")[1:]).encode("utf8") for layer in shard],
|
|
||||||
)
|
|
||||||
|
|
||||||
for layer in sorted(shard, key=lambda x: x.name):
|
for layer in sorted(shard, key=lambda x: x.name):
|
||||||
|
if "model." in layer.name or len(layer.name.split("/")) == 1:
|
||||||
|
layer_name = layer.name
|
||||||
|
print(layer_name)
|
||||||
|
else:
|
||||||
|
layer_name = "/".join(layer.name.split("/")[1:])
|
||||||
param_dset = shard_file.create_dataset(
|
param_dset = shard_file.create_dataset(
|
||||||
"/".join(layer.name.split("/")[1:]), layer.numpy().shape, dtype=layer.numpy().dtype
|
layer_name, layer.numpy().shape, dtype=layer.numpy().dtype
|
||||||
)
|
)
|
||||||
param_dset[:] = layer.numpy()
|
param_dset[:] = layer.numpy()
|
||||||
|
layers.append(layer_name.encode("utf8"))
|
||||||
|
save_attributes_to_hdf5_group(shard_file, "layer_names", layers)
|
||||||
|
|
||||||
if push_to_hub:
|
if push_to_hub:
|
||||||
self._upload_modified_files(
|
self._upload_modified_files(
|
||||||
|
|||||||
@@ -77,9 +77,11 @@ if is_tf_available():
|
|||||||
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
BertConfig,
|
BertConfig,
|
||||||
|
RagRetriever,
|
||||||
TFAutoModel,
|
TFAutoModel,
|
||||||
TFAutoModelForSequenceClassification,
|
TFAutoModelForSequenceClassification,
|
||||||
TFBertModel,
|
TFBertModel,
|
||||||
|
TFRagModel,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
)
|
)
|
||||||
from transformers.generation_tf_utils import (
|
from transformers.generation_tf_utils import (
|
||||||
@@ -2167,6 +2169,18 @@ class UtilsFunctionsTest(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_special_layer_name_shardind(self):
|
||||||
|
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
|
||||||
|
model = TFRagModel.from_pretrained("facebook/rag-token-nq", retriever=retriever)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
|
||||||
|
model.save_pretrained(tmp_dir, max_shard_size=max_size)
|
||||||
|
ref_model = TFRagModel.from_pretrained(tmp_dir, retriever=retriever)
|
||||||
|
for p1, p2 in zip(model.weights, ref_model.weights):
|
||||||
|
assert np.allclose(p1.numpy(), p2.numpy())
|
||||||
|
|
||||||
def test_checkpoint_sharding_local(self):
|
def test_checkpoint_sharding_local(self):
|
||||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user