Fix flaky SwitchTransformersModelTest::test_training_gradient (#35587)

* fix

* Update tests/models/switch_transformers/test_modeling_switch_transformers.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2025-01-09 15:36:22 +01:00
committed by GitHub
parent eb4579cf43
commit 82dd6c14bb
2 changed files with 8 additions and 3 deletions

View File

@@ -576,6 +576,8 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
test_torchscript = False
# The small SWITCH_TRANSFORMERS model needs higher percentages for CPU/MP tests
model_split_percents = [0.5, 0.8, 0.9]
# `SwitchTransformers` is a MOE in which not all experts will get gradients because they are not all used in a single forward pass
test_all_params_have_gradient = False
def setUp(self):
self.model_tester = SwitchTransformersModelTester(self)

View File

@@ -221,6 +221,8 @@ class ModelTesterMixin:
test_mismatched_shapes = True
test_missing_keys = True
test_model_parallel = False
# Used in `check_training_gradient_checkpointing` to NOT check all params having gradient (e.g. for some MOE models)
test_all_params_have_gradient = True
is_encoder_decoder = False
has_attentions = True
_is_composite = False
@@ -895,6 +897,7 @@ class ModelTesterMixin:
loss.backward()
optimizer.step()
if self.test_all_params_have_gradient:
for k, v in model.named_parameters():
if v.requires_grad:
self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!")