fix double wrapping + test (#10583)

This commit is contained in:
Stas Bekman
2021-03-08 07:15:55 -08:00
committed by GitHub
parent b880508440
commit f882966004
2 changed files with 17 additions and 0 deletions

View File

@@ -738,6 +738,10 @@ class Trainer:
if self.deepspeed:
return self.deepspeed
# train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
if unwrap_model(model) is not model:
return model
# Mixed precision training with apex (torch < 1.6)
if self.use_apex and training:
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)