Reload checkpoint (#7984)

* Fix checkpoint loading in Trainer

* Fix typo
This commit is contained in:
Sylvain Gugger
2020-10-22 15:48:52 -04:00
committed by GitHub
parent 467573ddde
commit 5ae935d233
3 changed files with 35 additions and 17 deletions

View File

@@ -23,6 +23,7 @@ from typing import List, Optional, Union
import numpy as np
import torch
from torch.optim.lr_scheduler import SAVE_STATE_WARNING
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler
@@ -33,8 +34,6 @@ from .utils import logging
if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler."
logger = logging.get_logger(__name__)
@@ -112,10 +111,10 @@ def distributed_broadcast_scalars(
def reissue_pt_warnings(caught_warnings):
# Reissue warnings that are not the PT_LR_SCHEDULER_WARNING
# Reissue warnings that are not the SAVE_STATE_WARNING
if len(caught_warnings) > 1:
for w in caught_warnings:
if w.category != UserWarning or w.message != PT_LR_SCHEDULER_WARNING:
if w.category != UserWarning or w.message != SAVE_STATE_WARNING:
warnings.warn(w.message, w.category)