updated with latest PL and Ray (#15653)
This commit is contained in:
@@ -38,7 +38,7 @@ def get_checkpoint_callback(output_dir, metric):
|
|||||||
monitor=f"val_{metric}",
|
monitor=f"val_{metric}",
|
||||||
mode="max",
|
mode="max",
|
||||||
save_top_k=3,
|
save_top_k=3,
|
||||||
period=1, # maybe save a checkpoint every time val is run, not just end of epoch.
|
every_n_epochs=1, # maybe save a checkpoint every time val is run, not just end of epoch.
|
||||||
)
|
)
|
||||||
return checkpoint_callback
|
return checkpoint_callback
|
||||||
|
|
||||||
|
|||||||
@@ -254,7 +254,7 @@ class GenerativeQAModule(BaseTransformer):
|
|||||||
def training_step(self, batch, batch_idx) -> Dict:
|
def training_step(self, batch, batch_idx) -> Dict:
|
||||||
loss_tensors = self._step(batch)
|
loss_tensors = self._step(batch)
|
||||||
|
|
||||||
logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
logs = {name: loss.detach() for name, loss in zip(self.loss_names, loss_tensors)}
|
||||||
# tokens per batch
|
# tokens per batch
|
||||||
tgt_pad_token_id = (
|
tgt_pad_token_id = (
|
||||||
self.tokenizer.generator.pad_token_id
|
self.tokenizer.generator.pad_token_id
|
||||||
@@ -517,7 +517,7 @@ def main(args=None, model=None) -> GenerativeQAModule:
|
|||||||
raise RuntimeError("Please install Ray to use the Ray " "distributed retriever.")
|
raise RuntimeError("Please install Ray to use the Ray " "distributed retriever.")
|
||||||
# Connect to an existing Ray cluster.
|
# Connect to an existing Ray cluster.
|
||||||
try:
|
try:
|
||||||
ray.init(address=args.ray_address)
|
ray.init(address=args.ray_address, namespace="rag")
|
||||||
except (ConnectionError, ValueError):
|
except (ConnectionError, ValueError):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Connection to Ray cluster failed. Make sure a Ray"
|
"Connection to Ray cluster failed. Make sure a Ray"
|
||||||
|
|||||||
@@ -266,6 +266,15 @@ class BaseTransformer(pl.LightningModule):
|
|||||||
parser.add_argument("--adafactor", action="store_true")
|
parser.add_argument("--adafactor", action="store_true")
|
||||||
|
|
||||||
|
|
||||||
|
class InitCallback(pl.Callback):
|
||||||
|
# This method is better that using a custom DDP plugging with the latest pytorch-lightning (@shamanez)
|
||||||
|
def on_sanity_check_start(self, trainer, pl_module):
|
||||||
|
if (
|
||||||
|
trainer.is_global_zero and trainer.global_rank == 0
|
||||||
|
): # we initialize the retriever only on master worker with RAY. In new pytorch-lightning accelorators are removed.
|
||||||
|
pl_module.model.rag.retriever.init_retrieval() # better to use hook functions.
|
||||||
|
|
||||||
|
|
||||||
class LoggingCallback(pl.Callback):
|
class LoggingCallback(pl.Callback):
|
||||||
def on_batch_end(self, trainer, pl_module):
|
def on_batch_end(self, trainer, pl_module):
|
||||||
lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
|
lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
|
||||||
@@ -368,19 +377,21 @@ def generic_train(
|
|||||||
# TODO: remove with PyTorch 1.6 since pl uses native amp
|
# TODO: remove with PyTorch 1.6 since pl uses native amp
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
train_params["precision"] = 16
|
train_params["precision"] = 16
|
||||||
train_params["amp_level"] = args.fp16_opt_level
|
# train_params["amp_level"] = args.fp16_opt_level
|
||||||
|
|
||||||
if args.gpus > 1:
|
if args.gpus > 1:
|
||||||
train_params["accelerator"] = "ddp"
|
train_params["accelerator"] = "auto" # "ddp"
|
||||||
|
train_params["strategy"] = "ddp"
|
||||||
|
|
||||||
train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
|
train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
|
||||||
train_params["profiler"] = None # extra_train_kwargs.get("profiler", None) #get unwanted logs
|
train_params["profiler"] = None # extra_train_kwargs.get("profiler", None) #get unwanted logs
|
||||||
|
train_params["devices"] = "auto"
|
||||||
|
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
args,
|
args,
|
||||||
weights_summary=None,
|
weights_summary=None,
|
||||||
callbacks=[logging_callback] + extra_callbacks + [checkpoint_callback],
|
callbacks=[logging_callback] + extra_callbacks + [checkpoint_callback] + [InitCallback()],
|
||||||
plugins=[custom_ddp_plugin],
|
# plugins=[custom_ddp_plugin],
|
||||||
logger=logger,
|
logger=logger,
|
||||||
**train_params,
|
**train_params,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ faiss-cpu >= 1.6.3
|
|||||||
datasets >= 1.0.1
|
datasets >= 1.0.1
|
||||||
psutil >= 5.7.0
|
psutil >= 5.7.0
|
||||||
torch >= 1.4.0
|
torch >= 1.4.0
|
||||||
|
ray >= 1.10.0
|
||||||
|
pytorch-lightning >= 1.5.10
|
||||||
transformers
|
transformers
|
||||||
pytorch-lightning
|
GitPython
|
||||||
GitPython
|
|
||||||
Reference in New Issue
Block a user