fix double wrapping + test (#10583)

This commit is contained in:
Stas Bekman
2021-03-08 07:15:55 -08:00
committed by GitHub
parent b880508440
commit f882966004
2 changed files with 17 additions and 0 deletions

View File

@@ -574,6 +574,19 @@ class TrainerIntegrationTest(unittest.TestCase):
trainer.train()
self.check_trained_model(trainer.model)
@require_torch_multi_gpu
def test_run_seq2seq_double_train_wrap_once(self):
# test that we don't wrap the model more than once
# since wrapping primarily happens on multi-gpu setup we want multiple gpus to test for
# example DataParallel(DataParallel(model))
trainer = get_regression_trainer()
trainer.train()
model_wrapped_before = trainer.model_wrapped
trainer.train()
model_wrapped_after = trainer.model_wrapped
self.assertIs(model_wrapped_before, model_wrapped_after, "should be not wrapped twice")
def test_can_resume_training(self):
if torch.cuda.device_count() > 2:
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of