Fix DETA save_pretrained (#30326)

* Add class_embed to tied weights for DETA

* Fix test_tied_weights_keys for DETA model

* Replace error raise with assert statement
This commit is contained in:
Pavel Iakubovskii
2024-04-22 17:11:13 +01:00
committed by GitHub
parent 6c7335e053
commit 13b3b90ab1
3 changed files with 44 additions and 3 deletions

View File

@@ -2025,8 +2025,8 @@ class ModelTesterMixin:
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
# Detect we get a hit for each key
for key in tied_weight_keys:
if not any(re.search(key, p) for group in tied_params for p in group):
raise ValueError(f"{key} is not a tied weight key for {model_class}.")
is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")
# Removed tied weights found from tied params -> there should only be one left after
for key in tied_weight_keys: