fix double wrapping + test (#10583)
This commit is contained in:
@@ -738,6 +738,10 @@ class Trainer:
|
|||||||
if self.deepspeed:
|
if self.deepspeed:
|
||||||
return self.deepspeed
|
return self.deepspeed
|
||||||
|
|
||||||
|
# train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
|
||||||
|
if unwrap_model(model) is not model:
|
||||||
|
return model
|
||||||
|
|
||||||
# Mixed precision training with apex (torch < 1.6)
|
# Mixed precision training with apex (torch < 1.6)
|
||||||
if self.use_apex and training:
|
if self.use_apex and training:
|
||||||
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
|
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
|
||||||
|
|||||||
@@ -574,6 +574,19 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
trainer.train()
|
trainer.train()
|
||||||
self.check_trained_model(trainer.model)
|
self.check_trained_model(trainer.model)
|
||||||
|
|
||||||
|
@require_torch_multi_gpu
|
||||||
|
def test_run_seq2seq_double_train_wrap_once(self):
|
||||||
|
# test that we don't wrap the model more than once
|
||||||
|
# since wrapping primarily happens on multi-gpu setup we want multiple gpus to test for
|
||||||
|
# example DataParallel(DataParallel(model))
|
||||||
|
|
||||||
|
trainer = get_regression_trainer()
|
||||||
|
trainer.train()
|
||||||
|
model_wrapped_before = trainer.model_wrapped
|
||||||
|
trainer.train()
|
||||||
|
model_wrapped_after = trainer.model_wrapped
|
||||||
|
self.assertIs(model_wrapped_before, model_wrapped_after, "should be not wrapped twice")
|
||||||
|
|
||||||
def test_can_resume_training(self):
|
def test_can_resume_training(self):
|
||||||
if torch.cuda.device_count() > 2:
|
if torch.cuda.device_count() > 2:
|
||||||
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
|
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
|
||||||
|
|||||||
Reference in New Issue
Block a user