Fix DETA save_pretrained (#30326)
* Add class_embed to tied weights for DETA * Fix test_tied_weights_keys for DETA model * Replace error raise with assert statement
This commit is contained in:
committed by
GitHub
parent
6c7335e053
commit
13b3b90ab1
@@ -2025,8 +2025,8 @@ class ModelTesterMixin:
|
||||
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
|
||||
# Detect we get a hit for each key
|
||||
for key in tied_weight_keys:
|
||||
if not any(re.search(key, p) for group in tied_params for p in group):
|
||||
raise ValueError(f"{key} is not a tied weight key for {model_class}.")
|
||||
is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
|
||||
self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")
|
||||
|
||||
# Removed tied weights found from tied params -> there should only be one left after
|
||||
for key in tied_weight_keys:
|
||||
|
||||
Reference in New Issue
Block a user