minor fixes in original RAG training (#12395)
This commit is contained in:
@@ -36,7 +36,7 @@ def get_checkpoint_callback(output_dir, metric):
|
|||||||
dirpath=output_dir,
|
dirpath=output_dir,
|
||||||
filename=exp,
|
filename=exp,
|
||||||
monitor=f"val_{metric}",
|
monitor=f"val_{metric}",
|
||||||
mode="min",
|
mode="max",
|
||||||
save_top_k=3,
|
save_top_k=3,
|
||||||
period=1, # maybe save a checkpoint every time val is run, not just end of epoch.
|
period=1, # maybe save a checkpoint every time val is run, not just end of epoch.
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -532,8 +532,8 @@ def main(args=None, model=None) -> GenerativeQAModule:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
# Create Ray actors only for rank 0.
|
# Create Ray actors only for rank 0.
|
||||||
if ("LOCAL_RANK" not in os.environ or os.environ["LOCAL_RANK"] == 0) and (
|
if ("LOCAL_RANK" not in os.environ or int(os.environ["LOCAL_RANK"]) == 0) and (
|
||||||
"NODE_RANK" not in os.environ or os.environ["NODE_RANK"] == 0
|
"NODE_RANK" not in os.environ or int(os.environ["NODE_RANK"]) == 0
|
||||||
):
|
):
|
||||||
remote_cls = ray.remote(RayRetriever)
|
remote_cls = ray.remote(RayRetriever)
|
||||||
named_actors = [
|
named_actors = [
|
||||||
|
|||||||
Reference in New Issue
Block a user