minor fixes in original RAG training (#12395)

This commit is contained in:
Shamane Siri
2021-06-30 00:39:48 +12:00
committed by GitHub
parent e3f39a2952
commit 5257818e68
2 changed files with 3 additions and 3 deletions

View File

@@ -36,7 +36,7 @@ def get_checkpoint_callback(output_dir, metric):
dirpath=output_dir,
filename=exp,
monitor=f"val_{metric}",
mode="min",
mode="max",
save_top_k=3,
period=1, # maybe save a checkpoint every time val is run, not just end of epoch.
)

View File

@@ -532,8 +532,8 @@ def main(args=None, model=None) -> GenerativeQAModule:
raise
# Create Ray actors only for rank 0.
if ("LOCAL_RANK" not in os.environ or os.environ["LOCAL_RANK"] == 0) and (
"NODE_RANK" not in os.environ or os.environ["NODE_RANK"] == 0
if ("LOCAL_RANK" not in os.environ or int(os.environ["LOCAL_RANK"]) == 0) and (
"NODE_RANK" not in os.environ or int(os.environ["NODE_RANK"]) == 0
):
remote_cls = ray.remote(RayRetriever)
named_actors = [