Fix flaky SwitchTransformersModelTest::test_training_gradient (#35587)

* fix

* Update tests/models/switch_transformers/test_modeling_switch_transformers.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2025-01-09 15:36:22 +01:00
committed by GitHub
parent eb4579cf43
commit 82dd6c14bb
2 changed files with 8 additions and 3 deletions

View File

@@ -221,6 +221,8 @@ class ModelTesterMixin:
test_mismatched_shapes = True
test_missing_keys = True
test_model_parallel = False
# Set to `False` to skip the check in `check_training_gradient_checkpointing` that every trainable parameter has a gradient (e.g. for some MoE models)
test_all_params_have_gradient = True
is_encoder_decoder = False
has_attentions = True
_is_composite = False
@@ -895,9 +897,10 @@ class ModelTesterMixin:
loss.backward()
optimizer.step()
for k, v in model.named_parameters():
if v.requires_grad:
self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!")
if self.test_all_params_have_gradient:
for k, v in model.named_parameters():
if v.requires_grad:
self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!")
def test_training(self):
if not self.model_tester.is_training: