updated with latest PL and Ray (#15653)
This commit is contained in:
@@ -266,6 +266,15 @@ class BaseTransformer(pl.LightningModule):
|
||||
parser.add_argument("--adafactor", action="store_true")
|
||||
|
||||
|
||||
class InitCallback(pl.Callback):
    """Callback that initializes the RAG retriever from a Lightning hook.

    Using a hook is preferred over a custom DDP plugin with the latest
    pytorch-lightning (@shamanez).
    """

    def on_sanity_check_start(self, trainer, pl_module):
        """Initialize retrieval on the global-zero process only.

        Args:
            trainer: The trainer; its ``is_global_zero`` and ``global_rank``
                attributes are consulted.
            pl_module: The module being trained; its
                ``model.rag.retriever`` is the object that gets initialized.
        """
        # Only the master worker initializes the retriever when running with
        # Ray. In new pytorch-lightning versions accelerators are removed.
        if not trainer.is_global_zero or trainer.global_rank != 0:
            return

        # Better to use hook functions for this initialization.
        rag_retriever = pl_module.model.rag.retriever
        rag_retriever.init_retrieval()
|
||||
|
||||
|
||||
class LoggingCallback(pl.Callback):
|
||||
def on_batch_end(self, trainer, pl_module):
|
||||
lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
|
||||
@@ -368,19 +377,21 @@ def generic_train(
|
||||
# TODO: remove with PyTorch 1.6 since pl uses native amp
|
||||
if args.fp16:
|
||||
train_params["precision"] = 16
|
||||
train_params["amp_level"] = args.fp16_opt_level
|
||||
# train_params["amp_level"] = args.fp16_opt_level
|
||||
|
||||
if args.gpus > 1:
|
||||
train_params["accelerator"] = "ddp"
|
||||
train_params["accelerator"] = "auto" # "ddp"
|
||||
train_params["strategy"] = "ddp"
|
||||
|
||||
train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
|
||||
train_params["profiler"] = None # extra_train_kwargs.get("profiler", None) #get unwanted logs
|
||||
train_params["devices"] = "auto"
|
||||
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
weights_summary=None,
|
||||
callbacks=[logging_callback] + extra_callbacks + [checkpoint_callback],
|
||||
plugins=[custom_ddp_plugin],
|
||||
callbacks=[logging_callback] + extra_callbacks + [checkpoint_callback] + [InitCallback()],
|
||||
# plugins=[custom_ddp_plugin],
|
||||
logger=logger,
|
||||
**train_params,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user