Fix flaky SwitchTransformersModelTest::test_training_gradient (#35587)
* fix
* Update tests/models/switch_transformers/test_modeling_switch_transformers.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -576,6 +576,8 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
|
|||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
# The small SWITCH_TRANSFORMERS model needs higher percentages for CPU/MP tests
|
# The small SWITCH_TRANSFORMERS model needs higher percentages for CPU/MP tests
|
||||||
model_split_percents = [0.5, 0.8, 0.9]
|
model_split_percents = [0.5, 0.8, 0.9]
|
||||||
|
# `SwitchTransformers` is a MoE model in which not every expert receives a gradient, because not all experts are used in a single forward pass
|
||||||
|
test_all_params_have_gradient = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = SwitchTransformersModelTester(self)
|
self.model_tester = SwitchTransformersModelTester(self)
|
||||||
|
|||||||
@@ -221,6 +221,8 @@ class ModelTesterMixin:
|
|||||||
test_mismatched_shapes = True
|
test_mismatched_shapes = True
|
||||||
test_missing_keys = True
|
test_missing_keys = True
|
||||||
test_model_parallel = False
|
test_model_parallel = False
|
||||||
|
# Used in `check_training_gradient_checkpointing`: set to `False` to skip the check that every parameter has a gradient (e.g. for some MoE models)
|
||||||
|
test_all_params_have_gradient = True
|
||||||
is_encoder_decoder = False
|
is_encoder_decoder = False
|
||||||
has_attentions = True
|
has_attentions = True
|
||||||
_is_composite = False
|
_is_composite = False
|
||||||
@@ -895,9 +897,10 @@ class ModelTesterMixin:
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
for k, v in model.named_parameters():
|
if self.test_all_params_have_gradient:
|
||||||
if v.requires_grad:
|
for k, v in model.named_parameters():
|
||||||
self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!")
|
if v.requires_grad:
|
||||||
|
self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!")
|
||||||
|
|
||||||
def test_training(self):
|
def test_training(self):
|
||||||
if not self.model_tester.is_training:
|
if not self.model_tester.is_training:
|
||||||
|
|||||||
Reference in New Issue
Block a user