Fix rag finetuning + add finetuning test (#8585)
* replace init_ddp_connection for index init * style * add finetune test * add test data * move generate tensors to device * add test on EM metric * style * allow multi process test * keep gloo process group for retrieval * add multi-gpu test * use custom accelerator * clean test finetune * minor * style * style * typo * use python call instead of imported main fumction * return_dict fix in modeling_rag * use float32 in retrieval * store as float32 as well in the custom knowledge dataset example * style * rename to finetune_rag * style * update readme * rename utils and callbacks to utils_rag and callbacks_rag * fix test * patrick's comments * generate dummy data in the finetue test script * remove dummy data files * style
This commit is contained in:
@@ -384,6 +384,8 @@ def generic_train(
|
||||
train_params["distributed_backend"] = "ddp"
|
||||
|
||||
train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
|
||||
train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
|
||||
train_params["profiler"] = extra_train_kwargs.get("profiler", None)
|
||||
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
|
||||
Reference in New Issue
Block a user