Reload checkpoint (#7984)
* Fix checkpoint loading in Trainer * Fix typo
This commit is contained in:
@@ -23,6 +23,7 @@ from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import SAVE_STATE_WARNING
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.data.sampler import RandomSampler, Sampler
|
||||
|
||||
@@ -33,8 +34,6 @@ from .utils import logging
|
||||
if is_torch_tpu_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler."
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -112,10 +111,10 @@ def distributed_broadcast_scalars(
|
||||
|
||||
|
||||
def reissue_pt_warnings(caught_warnings):
|
||||
# Reissue warnings that are not the PT_LR_SCHEDULER_WARNING
|
||||
# Reissue warnings that are not the SAVE_STATE_WARNING
|
||||
if len(caught_warnings) > 1:
|
||||
for w in caught_warnings:
|
||||
if w.category != UserWarning or w.message != PT_LR_SCHEDULER_WARNING:
|
||||
if w.category != UserWarning or w.message != SAVE_STATE_WARNING:
|
||||
warnings.warn(w.message, w.category)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user