updated the original RAG implementation to be compatible with latest Pytorch-Lightning (#11806)

* updated the original RAG implementation to be compatible with the latest PL version

* updated the requirements.txt file

* execute make style

* code quality test

* code quality

* conflix resolved in requirement.txt

* code quality

* changed the MyDDP class name to CustomDDP
This commit is contained in:
Shamane Siri
2021-06-09 00:42:49 +12:00
committed by GitHub
parent 70f88eeccc
commit e33085d648
5 changed files with 26 additions and 38 deletions

View File

@@ -1,5 +1,4 @@
import logging
import os
from pathlib import Path
import numpy as np
@@ -34,9 +33,10 @@ def get_checkpoint_callback(output_dir, metric):
)
checkpoint_callback = ModelCheckpoint(
filepath=os.path.join(output_dir, exp),
dirpath=output_dir,
filename=exp,
monitor=f"val_{metric}",
mode="max",
mode="min",
save_top_k=3,
period=1, # maybe save a checkpoint every time val is run, not just end of epoch.
)