updated the original RAG implementation to be compatible with latest Pytorch-Lightning (#11806)
* updated the original RAG implementation to be compatible with the latest PL version * updated the requirements.txt file * execute make style * code quality test * code quality * conflix resolved in requirement.txt * code quality * changed the MyDDP class name to CustomDDP
This commit is contained in:
@@ -167,8 +167,8 @@ class BaseTransformer(pl.LightningModule):
|
||||
effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
|
||||
return (self.dataset_size / effective_batch_size) * self.hparams.max_epochs
|
||||
|
||||
def setup(self, mode):
|
||||
if mode == "test":
|
||||
def setup(self, stage):
|
||||
if stage == "test":
|
||||
self.dataset_size = len(self.test_dataloader().dataset)
|
||||
else:
|
||||
self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
|
||||
@@ -341,6 +341,7 @@ def generic_train(
|
||||
args: argparse.Namespace,
|
||||
early_stopping_callback=None,
|
||||
logger=True, # can pass WandbLogger() here
|
||||
custom_ddp_plugin=None,
|
||||
extra_callbacks=[],
|
||||
checkpoint_callback=None,
|
||||
logging_callback=None,
|
||||
@@ -370,18 +371,17 @@ def generic_train(
|
||||
train_params["amp_level"] = args.fp16_opt_level
|
||||
|
||||
if args.gpus > 1:
|
||||
train_params["distributed_backend"] = "ddp"
|
||||
train_params["accelerator"] = "ddp"
|
||||
|
||||
train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
|
||||
train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
|
||||
train_params["profiler"] = extra_train_kwargs.get("profiler", None)
|
||||
train_params["profiler"] = None # extra_train_kwargs.get("profiler", None) #get unwanted logs
|
||||
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
weights_summary=None,
|
||||
callbacks=[logging_callback] + extra_callbacks,
|
||||
callbacks=[logging_callback] + extra_callbacks + [checkpoint_callback],
|
||||
plugins=[custom_ddp_plugin],
|
||||
logger=logger,
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
**train_params,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user