minor fixes in original RAG training (#12395)
This commit is contained in:
@@ -36,7 +36,7 @@ def get_checkpoint_callback(output_dir, metric):
|
||||
dirpath=output_dir,
|
||||
filename=exp,
|
||||
monitor=f"val_{metric}",
|
||||
mode="min",
|
||||
mode="max",
|
||||
save_top_k=3,
|
||||
period=1, # maybe save a checkpoint every time val is run, not just end of epoch.
|
||||
)
|
||||
|
||||
@@ -532,8 +532,8 @@ def main(args=None, model=None) -> GenerativeQAModule:
|
||||
raise
|
||||
|
||||
# Create Ray actors only for rank 0.
|
||||
if ("LOCAL_RANK" not in os.environ or os.environ["LOCAL_RANK"] == 0) and (
|
||||
"NODE_RANK" not in os.environ or os.environ["NODE_RANK"] == 0
|
||||
if ("LOCAL_RANK" not in os.environ or int(os.environ["LOCAL_RANK"]) == 0) and (
|
||||
"NODE_RANK" not in os.environ or int(os.environ["NODE_RANK"]) == 0
|
||||
):
|
||||
remote_cls = ray.remote(RayRetriever)
|
||||
named_actors = [
|
||||
|
||||
Reference in New Issue
Block a user