updated the original RAG implementation to be compatible with latest Pytorch-Lightning (#11806)
* updated the original RAG implementation to be compatible with the latest PL version * updated the requirements.txt file * execute make style * code quality test * code quality * conflix resolved in requirement.txt * code quality * changed the MyDDP class name to CustomDDP
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -34,9 +33,10 @@ def get_checkpoint_callback(output_dir, metric):
|
||||
)
|
||||
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
filepath=os.path.join(output_dir, exp),
|
||||
dirpath=output_dir,
|
||||
filename=exp,
|
||||
monitor=f"val_{metric}",
|
||||
mode="max",
|
||||
mode="min",
|
||||
save_top_k=3,
|
||||
period=1, # maybe save a checkpoint every time val is run, not just end of epoch.
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user