From 7563d5a3cf8d158fce3c83db55468240a6badb6f Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Mon, 28 Sep 2020 08:20:10 -0400 Subject: [PATCH] Catch PyTorch warning when saving/loading scheduler (#7401) --- src/transformers/trainer.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 3ce676854d..7d9a093d6c 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -59,6 +59,8 @@ from .utils import logging _use_native_amp = False _use_apex = False +PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler." + # Check if Pytorch version >= 1.6 to switch between Native AMP and Apex if version.parse(torch.__version__) < version.parse("1.6"): from .file_utils import is_apex_available @@ -99,6 +101,14 @@ if is_ray_available(): logger = logging.get_logger(__name__) +def reissue_pt_warnings(caught_warnings): + # Reissue warnings that are not the PT_LR_SCHEDULER_WARNING + if len(caught_warnings) > 1: + for w in caught_warnings: + if w.category != UserWarning or w.message != PT_LR_SCHEDULER_WARNING: + warnings.warn(w.message, w.category) + + @contextmanager def torch_distributed_zero_first(local_rank: int): """ @@ -643,7 +653,9 @@ class Trainer: self.optimizer.load_state_dict( torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device) ) - self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt"))) + with warnings.catch_warnings(record=True) as caught_warnings: + self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt"))) + reissue_pt_warnings(caught_warnings) model = self.model if self.args.fp16 and _use_apex: @@ -821,10 +833,14 @@ class Trainer: if is_torch_tpu_available(): xm.rendezvous("saving_optimizer_states") xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) - xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) + with warnings.catch_warnings(record=True) as caught_warnings: + xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) + reissue_pt_warnings(caught_warnings) elif self.is_world_process_zero(): torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) - torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) + with warnings.catch_warnings(record=True) as caught_warnings: + torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) + reissue_pt_warnings(caught_warnings) epoch_pbar.update(1) if self.args.max_steps > 0 and self.global_step >= self.args.max_steps: