From 82dd6c14bb811379d126acb801b8f7fd857eb2cc Mon Sep 17 00:00:00 2001
From: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Date: Thu, 9 Jan 2025 15:36:22 +0100
Subject: [PATCH] Fix flaky `SwitchTransformersModelTest::test_training_gradient` (#35587)

* fix

* Update tests/models/switch_transformers/test_modeling_switch_transformers.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: ydshieh
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
---
 .../test_modeling_switch_transformers.py | 2 ++
 tests/test_modeling_common.py            | 9 ++++++---
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py
index 7adb1f40c6..32597f8ce2 100644
--- a/tests/models/switch_transformers/test_modeling_switch_transformers.py
+++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py
@@ -576,6 +576,8 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
     test_torchscript = False
     # The small SWITCH_TRANSFORMERS model needs higher percentages for CPU/MP tests
     model_split_percents = [0.5, 0.8, 0.9]
+    # `SwitchTransformers` is a MOE in which not all experts will get gradients because they are not all used in a single forward pass
+    test_all_params_have_gradient = False
 
     def setUp(self):
         self.model_tester = SwitchTransformersModelTester(self)
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 1fe8f043dc..c29a15efd3 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -221,6 +221,8 @@ class ModelTesterMixin:
     test_mismatched_shapes = True
     test_missing_keys = True
     test_model_parallel = False
+    # Used in `check_training_gradient_checkpointing` to NOT check all params having gradient (e.g. for some MOE models)
+    test_all_params_have_gradient = True
     is_encoder_decoder = False
     has_attentions = True
     _is_composite = False
@@ -895,9 +897,10 @@ class ModelTesterMixin:
                 loss.backward()
                 optimizer.step()
 
-                for k, v in model.named_parameters():
-                    if v.requires_grad:
-                        self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!")
+                if self.test_all_params_have_gradient:
+                    for k, v in model.named_parameters():
+                        if v.requires_grad:
+                            self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!")
 
     def test_training(self):
         if not self.model_tester.is_training: