Catch PyTorch warning when saving/loading scheduler (#7401)
This commit is contained in:
@@ -59,6 +59,8 @@ from .utils import logging
|
||||
_use_native_amp = False
|
||||
_use_apex = False
|
||||
|
||||
PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler."
|
||||
|
||||
# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
|
||||
if version.parse(torch.__version__) < version.parse("1.6"):
|
||||
from .file_utils import is_apex_available
|
||||
@@ -99,6 +101,14 @@ if is_ray_available():
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def reissue_pt_warnings(caught_warnings):
|
||||
# Reissue warnings that are not the PT_LR_SCHEDULER_WARNING
|
||||
if len(caught_warnings) > 1:
|
||||
for w in caught_warnings:
|
||||
if w.category != UserWarning or w.message != PT_LR_SCHEDULER_WARNING:
|
||||
warnings.warn(w.message, w.category)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def torch_distributed_zero_first(local_rank: int):
|
||||
"""
|
||||
@@ -643,7 +653,9 @@ class Trainer:
|
||||
self.optimizer.load_state_dict(
|
||||
torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
|
||||
)
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
|
||||
model = self.model
|
||||
if self.args.fp16 and _use_apex:
|
||||
@@ -821,10 +833,14 @@ class Trainer:
|
||||
if is_torch_tpu_available():
|
||||
xm.rendezvous("saving_optimizer_states")
|
||||
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
elif self.is_world_process_zero():
|
||||
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
|
||||
epoch_pbar.update(1)
|
||||
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
|
||||
|
||||
Reference in New Issue
Block a user