Catch PyTorch warning when saving/loading scheduler (#7401)
This commit is contained in:
@@ -59,6 +59,8 @@ from .utils import logging
|
|||||||
_use_native_amp = False
|
_use_native_amp = False
|
||||||
_use_apex = False
|
_use_apex = False
|
||||||
|
|
||||||
|
PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler."
|
||||||
|
|
||||||
# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
|
# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
|
||||||
if version.parse(torch.__version__) < version.parse("1.6"):
|
if version.parse(torch.__version__) < version.parse("1.6"):
|
||||||
from .file_utils import is_apex_available
|
from .file_utils import is_apex_available
|
||||||
@@ -99,6 +101,14 @@ if is_ray_available():
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def reissue_pt_warnings(caught_warnings):
|
||||||
|
# Reissue warnings that are not the PT_LR_SCHEDULER_WARNING
|
||||||
|
if len(caught_warnings) > 1:
|
||||||
|
for w in caught_warnings:
|
||||||
|
if w.category != UserWarning or w.message != PT_LR_SCHEDULER_WARNING:
|
||||||
|
warnings.warn(w.message, w.category)
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def torch_distributed_zero_first(local_rank: int):
|
def torch_distributed_zero_first(local_rank: int):
|
||||||
"""
|
"""
|
||||||
@@ -643,7 +653,9 @@ class Trainer:
|
|||||||
self.optimizer.load_state_dict(
|
self.optimizer.load_state_dict(
|
||||||
torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
|
torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
|
||||||
)
|
)
|
||||||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
|
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
|
||||||
|
reissue_pt_warnings(caught_warnings)
|
||||||
|
|
||||||
model = self.model
|
model = self.model
|
||||||
if self.args.fp16 and _use_apex:
|
if self.args.fp16 and _use_apex:
|
||||||
@@ -821,10 +833,14 @@ class Trainer:
|
|||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
xm.rendezvous("saving_optimizer_states")
|
xm.rendezvous("saving_optimizer_states")
|
||||||
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||||
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
|
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||||
|
reissue_pt_warnings(caught_warnings)
|
||||||
elif self.is_world_process_zero():
|
elif self.is_world_process_zero():
|
||||||
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
|
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||||
|
reissue_pt_warnings(caught_warnings)
|
||||||
|
|
||||||
epoch_pbar.update(1)
|
epoch_pbar.update(1)
|
||||||
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
|
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
|
||||||
|
|||||||
Reference in New Issue
Block a user