updated the original RAG implementation to be compatible with the latest PyTorch-Lightning (#11806)

* updated the original RAG implementation to be compatible with the latest PL version

* updated the requirements.txt file

* execute make style

* code quality test

* code quality

* conflict resolved in requirements.txt

* code quality

* changed the MyDDP class name to CustomDDP
This commit is contained in:
Shamane Siri
2021-06-09 00:42:49 +12:00
committed by GitHub
parent 70f88eeccc
commit e33085d648
5 changed files with 26 additions and 38 deletions

View File

@@ -167,8 +167,8 @@ class BaseTransformer(pl.LightningModule):
effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
return (self.dataset_size / effective_batch_size) * self.hparams.max_epochs
def setup(self, mode):
if mode == "test":
def setup(self, stage):
if stage == "test":
self.dataset_size = len(self.test_dataloader().dataset)
else:
self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
@@ -341,6 +341,7 @@ def generic_train(
args: argparse.Namespace,
early_stopping_callback=None,
logger=True, # can pass WandbLogger() here
custom_ddp_plugin=None,
extra_callbacks=[],
checkpoint_callback=None,
logging_callback=None,
@@ -370,18 +371,17 @@ def generic_train(
train_params["amp_level"] = args.fp16_opt_level
if args.gpus > 1:
train_params["distributed_backend"] = "ddp"
train_params["accelerator"] = "ddp"
train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
train_params["profiler"] = extra_train_kwargs.get("profiler", None)
train_params["profiler"] = None # extra_train_kwargs.get("profiler", None) #get unwanted logs
trainer = pl.Trainer.from_argparse_args(
args,
weights_summary=None,
callbacks=[logging_callback] + extra_callbacks,
callbacks=[logging_callback] + extra_callbacks + [checkpoint_callback],
plugins=[custom_ddp_plugin],
logger=logger,
checkpoint_callback=checkpoint_callback,
**train_params,
)