Fix flaky SwitchTransformersModelTest::test_training_gradient (#35587)

* fix

* Update tests/models/switch_transformers/test_modeling_switch_transformers.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2025-01-09 15:36:22 +01:00
committed by GitHub
parent eb4579cf43
commit 82dd6c14bb
2 changed files with 8 additions and 3 deletions

View File

@@ -576,6 +576,8 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
test_torchscript = False test_torchscript = False
# The small SWITCH_TRANSFORMERS model needs higher percentages for CPU/MP tests # The small SWITCH_TRANSFORMERS model needs higher percentages for CPU/MP tests
model_split_percents = [0.5, 0.8, 0.9] model_split_percents = [0.5, 0.8, 0.9]
# `SwitchTransformers` is a mixture-of-experts (MoE) model: not all experts are used in a single forward pass, so not all of them receive gradients
test_all_params_have_gradient = False
def setUp(self): def setUp(self):
self.model_tester = SwitchTransformersModelTester(self) self.model_tester = SwitchTransformersModelTester(self)

View File

@@ -221,6 +221,8 @@ class ModelTesterMixin:
test_mismatched_shapes = True test_mismatched_shapes = True
test_missing_keys = True test_missing_keys = True
test_model_parallel = False test_model_parallel = False
# Set to `False` to skip the check in `check_training_gradient_checkpointing` that every trainable parameter has a gradient (e.g. for some MoE models)
test_all_params_have_gradient = True
is_encoder_decoder = False is_encoder_decoder = False
has_attentions = True has_attentions = True
_is_composite = False _is_composite = False
@@ -895,9 +897,10 @@ class ModelTesterMixin:
loss.backward() loss.backward()
optimizer.step() optimizer.step()
for k, v in model.named_parameters(): if self.test_all_params_have_gradient:
if v.requires_grad: for k, v in model.named_parameters():
self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!") if v.requires_grad:
self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!")
def test_training(self): def test_training(self):
if not self.model_tester.is_training: if not self.model_tester.is_training: