From 5257818e685c89fb473504afb5cbcd7b34301811 Mon Sep 17 00:00:00 2001 From: Shamane Siri Date: Wed, 30 Jun 2021 00:39:48 +1200 Subject: [PATCH] minor fixes in original RAG training (#12395) --- examples/research_projects/rag/callbacks_rag.py | 2 +- examples/research_projects/rag/finetune_rag.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/research_projects/rag/callbacks_rag.py b/examples/research_projects/rag/callbacks_rag.py index 3d8425e612..e9eda20de3 100644 --- a/examples/research_projects/rag/callbacks_rag.py +++ b/examples/research_projects/rag/callbacks_rag.py @@ -36,7 +36,7 @@ def get_checkpoint_callback(output_dir, metric): dirpath=output_dir, filename=exp, monitor=f"val_{metric}", - mode="min", + mode="max", save_top_k=3, period=1, # maybe save a checkpoint every time val is run, not just end of epoch. ) diff --git a/examples/research_projects/rag/finetune_rag.py b/examples/research_projects/rag/finetune_rag.py index b5ccaa228c..a1721623dd 100644 --- a/examples/research_projects/rag/finetune_rag.py +++ b/examples/research_projects/rag/finetune_rag.py @@ -532,8 +532,8 @@ def main(args=None, model=None) -> GenerativeQAModule: raise # Create Ray actors only for rank 0. - if ("LOCAL_RANK" not in os.environ or os.environ["LOCAL_RANK"] == 0) and ( - "NODE_RANK" not in os.environ or os.environ["NODE_RANK"] == 0 + if ("LOCAL_RANK" not in os.environ or int(os.environ["LOCAL_RANK"]) == 0) and ( + "NODE_RANK" not in os.environ or int(os.environ["NODE_RANK"]) == 0 ): remote_cls = ray.remote(RayRetriever) named_actors = [