updated the original RAG implementation to be compatible with latest Pytorch-Lightning (#11806)

* updated the original RAG implementation to be compatible with the latest PL version

* updated the requirements.txt file

* execute make style

* code quality test

* code quality

* conflix resolved in requirement.txt

* code quality

* changed the MyDDP class name to CustomDDP
This commit is contained in:
Shamane Siri
2021-06-09 00:42:49 +12:00
committed by GitHub
parent 70f88eeccc
commit e33085d648
5 changed files with 26 additions and 38 deletions

View File

@@ -1,5 +1,4 @@
import logging import logging
import os
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
@@ -34,9 +33,10 @@ def get_checkpoint_callback(output_dir, metric):
) )
checkpoint_callback = ModelCheckpoint( checkpoint_callback = ModelCheckpoint(
filepath=os.path.join(output_dir, exp), dirpath=output_dir,
filename=exp,
monitor=f"val_{metric}", monitor=f"val_{metric}",
mode="max", mode="min",
save_top_k=3, save_top_k=3,
period=1, # maybe save a checkpoint every time val is run, not just end of epoch. period=1, # maybe save a checkpoint every time val is run, not just end of epoch.
) )

View File

@@ -3,7 +3,6 @@ import random
import ray import ray
from transformers import RagConfig, RagRetriever, RagTokenizer from transformers import RagConfig, RagRetriever, RagTokenizer
from transformers.file_utils import requires_datasets, requires_faiss
from transformers.models.rag.retrieval_rag import CustomHFIndex from transformers.models.rag.retrieval_rag import CustomHFIndex
@@ -134,8 +133,6 @@ class RagRayDistributedRetriever(RagRetriever):
@classmethod @classmethod
def from_pretrained(cls, retriever_name_or_path, actor_handles, indexed_dataset=None, **kwargs): def from_pretrained(cls, retriever_name_or_path, actor_handles, indexed_dataset=None, **kwargs):
requires_datasets(cls)
requires_faiss(cls)
config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs) config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config) rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
question_encoder_tokenizer = rag_tokenizer.question_encoder question_encoder_tokenizer = rag_tokenizer.question_encoder

View File

@@ -13,8 +13,8 @@ import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from pytorch_lightning.accelerators.ddp_accelerator import DDPAccelerator import torch.distributed as torch_distrib
from pytorch_lightning.cluster_environments import TorchElasticEnvironment from pytorch_lightning.plugins.training_type import DDPPlugin
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from transformers import ( from transformers import (
@@ -36,7 +36,6 @@ if is_ray_available():
import ray import ray
from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever
from callbacks_rag import ( # noqa: E402 # isort:skipq from callbacks_rag import ( # noqa: E402 # isort:skipq
get_checkpoint_callback, get_checkpoint_callback,
get_early_stopping_callback, get_early_stopping_callback,
@@ -74,27 +73,19 @@ class AttrDict(dict):
self.__dict__ = self self.__dict__ = self
# In PTL >v1.0, `init_ddp_connection` method in the `LightningModule` class CustomDDP(DDPPlugin):
# is no longer used, and is moved into DDPAccelerator instead. def init_ddp_connection(self, global_rank=None, world_size=None) -> None:
# We override DDPAccelerator to add our custom logic for initializing the module = self.model
# retriever. global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank()
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/backends/test_accelerator_connector.py world_size = world_size if world_size is not None else self.cluster_environment.world_size()
os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
if not torch.distributed.is_initialized():
logger.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)
class CustomAccel(DDPAccelerator):
def __init__(self, trainer=None, **kwargs):
# Trainer is set later.
super().__init__(trainer, **kwargs)
def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True):
logger.info("Custom init_ddp_connection.")
module = self.trainer.model
if self.cluster_environment is None:
self.cluster_environment = TorchElasticEnvironment()
self.distributed_port = module.hparams.distributed_port
os.environ["MASTER_PORT"] = str(self.distributed_port)
super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
if module.is_rag_model: if module.is_rag_model:
self.distributed_port = module.hparams.distributed_port
if module.distributed_retriever == "pytorch": if module.distributed_retriever == "pytorch":
module.model.rag.retriever.init_retrieval(self.distributed_port) module.model.rag.retriever.init_retrieval(self.distributed_port)
elif module.distributed_retriever == "ray" and global_rank == 0: elif module.distributed_retriever == "ray" and global_rank == 0:
@@ -594,7 +585,7 @@ def main(args=None, model=None) -> GenerativeQAModule:
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric), checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
early_stopping_callback=es_callback, early_stopping_callback=es_callback,
logger=training_logger, logger=training_logger,
accelerator=CustomAccel() if args.gpus > 1 else None, custom_ddp_plugin=CustomDDP() if args.gpus > 1 else None,
profiler=pl.profiler.AdvancedProfiler() if args.profile else None, profiler=pl.profiler.AdvancedProfiler() if args.profile else None,
) )
pickle_save(model.hparams, model.output_dir / "hparams.pkl") pickle_save(model.hparams, model.output_dir / "hparams.pkl")

View File

@@ -167,8 +167,8 @@ class BaseTransformer(pl.LightningModule):
effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
return (self.dataset_size / effective_batch_size) * self.hparams.max_epochs return (self.dataset_size / effective_batch_size) * self.hparams.max_epochs
def setup(self, mode): def setup(self, stage):
if mode == "test": if stage == "test":
self.dataset_size = len(self.test_dataloader().dataset) self.dataset_size = len(self.test_dataloader().dataset)
else: else:
self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True) self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
@@ -341,6 +341,7 @@ def generic_train(
args: argparse.Namespace, args: argparse.Namespace,
early_stopping_callback=None, early_stopping_callback=None,
logger=True, # can pass WandbLogger() here logger=True, # can pass WandbLogger() here
custom_ddp_plugin=None,
extra_callbacks=[], extra_callbacks=[],
checkpoint_callback=None, checkpoint_callback=None,
logging_callback=None, logging_callback=None,
@@ -370,18 +371,17 @@ def generic_train(
train_params["amp_level"] = args.fp16_opt_level train_params["amp_level"] = args.fp16_opt_level
if args.gpus > 1: if args.gpus > 1:
train_params["distributed_backend"] = "ddp" train_params["accelerator"] = "ddp"
train_params["accumulate_grad_batches"] = args.accumulate_grad_batches train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
train_params["accelerator"] = extra_train_kwargs.get("accelerator", None) train_params["profiler"] = None # extra_train_kwargs.get("profiler", None) #get unwanted logs
train_params["profiler"] = extra_train_kwargs.get("profiler", None)
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer.from_argparse_args(
args, args,
weights_summary=None, weights_summary=None,
callbacks=[logging_callback] + extra_callbacks, callbacks=[logging_callback] + extra_callbacks + [checkpoint_callback],
plugins=[custom_ddp_plugin],
logger=logger, logger=logger,
checkpoint_callback=checkpoint_callback,
**train_params, **train_params,
) )

View File

@@ -3,5 +3,5 @@ datasets >= 1.0.1
psutil >= 5.7.0 psutil >= 5.7.0
torch >= 1.4.0 torch >= 1.4.0
transformers transformers
pytorch-lightning==1.0.4 pytorch-lightning==1.3.1
GitPython GitPython