[examples/seq2seq] fix PL deprecation warning (#8577)
* fix deprecation warning * fix
This commit is contained in:
@@ -1,5 +1,4 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -98,7 +97,8 @@ def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=Fa
|
|||||||
)
|
)
|
||||||
|
|
||||||
checkpoint_callback = ModelCheckpoint(
|
checkpoint_callback = ModelCheckpoint(
|
||||||
filepath=os.path.join(output_dir, exp),
|
dirpath=output_dir,
|
||||||
|
filename=exp,
|
||||||
monitor=f"val_{metric}",
|
monitor=f"val_{metric}",
|
||||||
mode="min" if "loss" in metric else "max",
|
mode="min" if "loss" in metric else "max",
|
||||||
save_top_k=save_top_k,
|
save_top_k=save_top_k,
|
||||||
|
|||||||
Reference in New Issue
Block a user