Updated the original RAG implementation to be compatible with the latest PyTorch-Lightning (#11806)
* updated the original RAG implementation to be compatible with the latest PL version * updated the requirements.txt file * execute make style * code quality test * code quality * conflict resolved in requirements.txt * code quality * changed the MyDDP class name to CustomDDP
This commit is contained in:
@@ -1,5 +1,4 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -34,9 +33,10 @@ def get_checkpoint_callback(output_dir, metric):
|
|||||||
)
|
)
|
||||||
|
|
||||||
checkpoint_callback = ModelCheckpoint(
|
checkpoint_callback = ModelCheckpoint(
|
||||||
filepath=os.path.join(output_dir, exp),
|
dirpath=output_dir,
|
||||||
|
filename=exp,
|
||||||
monitor=f"val_{metric}",
|
monitor=f"val_{metric}",
|
||||||
mode="max",
|
mode="min",
|
||||||
save_top_k=3,
|
save_top_k=3,
|
||||||
period=1, # maybe save a checkpoint every time val is run, not just end of epoch.
|
period=1, # maybe save a checkpoint every time val is run, not just end of epoch.
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import random
|
|||||||
|
|
||||||
import ray
|
import ray
|
||||||
from transformers import RagConfig, RagRetriever, RagTokenizer
|
from transformers import RagConfig, RagRetriever, RagTokenizer
|
||||||
from transformers.file_utils import requires_datasets, requires_faiss
|
|
||||||
from transformers.models.rag.retrieval_rag import CustomHFIndex
|
from transformers.models.rag.retrieval_rag import CustomHFIndex
|
||||||
|
|
||||||
|
|
||||||
@@ -134,8 +133,6 @@ class RagRayDistributedRetriever(RagRetriever):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, retriever_name_or_path, actor_handles, indexed_dataset=None, **kwargs):
|
def from_pretrained(cls, retriever_name_or_path, actor_handles, indexed_dataset=None, **kwargs):
|
||||||
requires_datasets(cls)
|
|
||||||
requires_faiss(cls)
|
|
||||||
config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
|
config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
|
||||||
rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
|
rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
|
||||||
question_encoder_tokenizer = rag_tokenizer.question_encoder
|
question_encoder_tokenizer = rag_tokenizer.question_encoder
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ import numpy as np
|
|||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from pytorch_lightning.accelerators.ddp_accelerator import DDPAccelerator
|
import torch.distributed as torch_distrib
|
||||||
from pytorch_lightning.cluster_environments import TorchElasticEnvironment
|
from pytorch_lightning.plugins.training_type import DDPPlugin
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@@ -36,7 +36,6 @@ if is_ray_available():
|
|||||||
import ray
|
import ray
|
||||||
from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever
|
from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever
|
||||||
|
|
||||||
|
|
||||||
from callbacks_rag import ( # noqa: E402 # isort:skipq
|
from callbacks_rag import ( # noqa: E402 # isort:skipq
|
||||||
get_checkpoint_callback,
|
get_checkpoint_callback,
|
||||||
get_early_stopping_callback,
|
get_early_stopping_callback,
|
||||||
@@ -74,27 +73,19 @@ class AttrDict(dict):
|
|||||||
self.__dict__ = self
|
self.__dict__ = self
|
||||||
|
|
||||||
|
|
||||||
# In PTL >v1.0, `init_ddp_connection` method in the `LightningModule`
|
class CustomDDP(DDPPlugin):
|
||||||
# is no longer used, and is moved into DDPAccelerator instead.
|
def init_ddp_connection(self, global_rank=None, world_size=None) -> None:
|
||||||
# We override DDPAccelerator to add our custom logic for initializing the
|
module = self.model
|
||||||
# retriever.
|
global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank()
|
||||||
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/backends/test_accelerator_connector.py
|
world_size = world_size if world_size is not None else self.cluster_environment.world_size()
|
||||||
|
os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
|
||||||
|
os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
|
||||||
|
if not torch.distributed.is_initialized():
|
||||||
|
logger.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
|
||||||
|
torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)
|
||||||
|
|
||||||
|
|
||||||
class CustomAccel(DDPAccelerator):
|
|
||||||
def __init__(self, trainer=None, **kwargs):
|
|
||||||
# Trainer is set later.
|
|
||||||
super().__init__(trainer, **kwargs)
|
|
||||||
|
|
||||||
def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True):
|
|
||||||
logger.info("Custom init_ddp_connection.")
|
|
||||||
module = self.trainer.model
|
|
||||||
if self.cluster_environment is None:
|
|
||||||
self.cluster_environment = TorchElasticEnvironment()
|
|
||||||
self.distributed_port = module.hparams.distributed_port
|
|
||||||
os.environ["MASTER_PORT"] = str(self.distributed_port)
|
|
||||||
super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
|
|
||||||
if module.is_rag_model:
|
if module.is_rag_model:
|
||||||
|
self.distributed_port = module.hparams.distributed_port
|
||||||
if module.distributed_retriever == "pytorch":
|
if module.distributed_retriever == "pytorch":
|
||||||
module.model.rag.retriever.init_retrieval(self.distributed_port)
|
module.model.rag.retriever.init_retrieval(self.distributed_port)
|
||||||
elif module.distributed_retriever == "ray" and global_rank == 0:
|
elif module.distributed_retriever == "ray" and global_rank == 0:
|
||||||
@@ -594,7 +585,7 @@ def main(args=None, model=None) -> GenerativeQAModule:
|
|||||||
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
|
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
|
||||||
early_stopping_callback=es_callback,
|
early_stopping_callback=es_callback,
|
||||||
logger=training_logger,
|
logger=training_logger,
|
||||||
accelerator=CustomAccel() if args.gpus > 1 else None,
|
custom_ddp_plugin=CustomDDP() if args.gpus > 1 else None,
|
||||||
profiler=pl.profiler.AdvancedProfiler() if args.profile else None,
|
profiler=pl.profiler.AdvancedProfiler() if args.profile else None,
|
||||||
)
|
)
|
||||||
pickle_save(model.hparams, model.output_dir / "hparams.pkl")
|
pickle_save(model.hparams, model.output_dir / "hparams.pkl")
|
||||||
|
|||||||
@@ -167,8 +167,8 @@ class BaseTransformer(pl.LightningModule):
|
|||||||
effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
|
effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
|
||||||
return (self.dataset_size / effective_batch_size) * self.hparams.max_epochs
|
return (self.dataset_size / effective_batch_size) * self.hparams.max_epochs
|
||||||
|
|
||||||
def setup(self, mode):
|
def setup(self, stage):
|
||||||
if mode == "test":
|
if stage == "test":
|
||||||
self.dataset_size = len(self.test_dataloader().dataset)
|
self.dataset_size = len(self.test_dataloader().dataset)
|
||||||
else:
|
else:
|
||||||
self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
|
self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
|
||||||
@@ -341,6 +341,7 @@ def generic_train(
|
|||||||
args: argparse.Namespace,
|
args: argparse.Namespace,
|
||||||
early_stopping_callback=None,
|
early_stopping_callback=None,
|
||||||
logger=True, # can pass WandbLogger() here
|
logger=True, # can pass WandbLogger() here
|
||||||
|
custom_ddp_plugin=None,
|
||||||
extra_callbacks=[],
|
extra_callbacks=[],
|
||||||
checkpoint_callback=None,
|
checkpoint_callback=None,
|
||||||
logging_callback=None,
|
logging_callback=None,
|
||||||
@@ -370,18 +371,17 @@ def generic_train(
|
|||||||
train_params["amp_level"] = args.fp16_opt_level
|
train_params["amp_level"] = args.fp16_opt_level
|
||||||
|
|
||||||
if args.gpus > 1:
|
if args.gpus > 1:
|
||||||
train_params["distributed_backend"] = "ddp"
|
train_params["accelerator"] = "ddp"
|
||||||
|
|
||||||
train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
|
train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
|
||||||
train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
|
train_params["profiler"] = None # extra_train_kwargs.get("profiler", None) #get unwanted logs
|
||||||
train_params["profiler"] = extra_train_kwargs.get("profiler", None)
|
|
||||||
|
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
args,
|
args,
|
||||||
weights_summary=None,
|
weights_summary=None,
|
||||||
callbacks=[logging_callback] + extra_callbacks,
|
callbacks=[logging_callback] + extra_callbacks + [checkpoint_callback],
|
||||||
|
plugins=[custom_ddp_plugin],
|
||||||
logger=logger,
|
logger=logger,
|
||||||
checkpoint_callback=checkpoint_callback,
|
|
||||||
**train_params,
|
**train_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -3,5 +3,5 @@ datasets >= 1.0.1
|
|||||||
psutil >= 5.7.0
|
psutil >= 5.7.0
|
||||||
torch >= 1.4.0
|
torch >= 1.4.0
|
||||||
transformers
|
transformers
|
||||||
pytorch-lightning==1.0.4
|
pytorch-lightning==1.3.1
|
||||||
GitPython
|
GitPython
|
||||||
Reference in New Issue
Block a user