Fix auto-resume training from checkpoint (#9822)
* Fix auto-resume training from checkpoint * style fixes
This commit is contained in:
@@ -77,15 +77,19 @@ class TrainOutput(NamedTuple):
|
|||||||
|
|
||||||
|
|
||||||
PREFIX_CHECKPOINT_DIR = "checkpoint"
|
PREFIX_CHECKPOINT_DIR = "checkpoint"
|
||||||
_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d)+$")
|
_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")
|
||||||
|
|
||||||
|
|
||||||
def get_last_checkpoint(folder):
|
def get_last_checkpoint(folder):
|
||||||
content = os.listdir(folder)
|
content = os.listdir(folder)
|
||||||
checkpoints = [path for path in content if _re_checkpoint.search(path) is not None and os.path.isdir(path)]
|
checkpoints = [
|
||||||
|
path
|
||||||
|
for path in content
|
||||||
|
if _re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path))
|
||||||
|
]
|
||||||
if len(checkpoints) == 0:
|
if len(checkpoints) == 0:
|
||||||
return
|
return
|
||||||
return max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0]))
|
return os.path.join(folder, max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0])))
|
||||||
|
|
||||||
|
|
||||||
class EvaluationStrategy(ExplicitEnum):
|
class EvaluationStrategy(ExplicitEnum):
|
||||||
|
|||||||
Reference in New Issue
Block a user