[s2s] Fix t5 warning for distributed eval (#7487)
This commit is contained in:
@@ -42,8 +42,7 @@ def eval_data_dir(
|
|||||||
task="summarization",
|
task="summarization",
|
||||||
local_rank=None,
|
local_rank=None,
|
||||||
num_return_sequences=1,
|
num_return_sequences=1,
|
||||||
src_lang=None,
|
dataset_kwargs: Dict = None,
|
||||||
tgt_lang=None,
|
|
||||||
prefix="",
|
prefix="",
|
||||||
**generate_kwargs,
|
**generate_kwargs,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
@@ -78,9 +77,8 @@ def eval_data_dir(
|
|||||||
max_target_length=1024,
|
max_target_length=1024,
|
||||||
type_path=type_path,
|
type_path=type_path,
|
||||||
n_obs=n_obs,
|
n_obs=n_obs,
|
||||||
src_lang=src_lang,
|
|
||||||
tgt_lang=tgt_lang,
|
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
|
**dataset_kwargs,
|
||||||
)
|
)
|
||||||
# I set shuffle=True for a more accurate progress bar.
|
# I set shuffle=True for a more accurate progress bar.
|
||||||
# If all the longest samples are first, the prog bar estimate is too high at the beginning.
|
# If all the longest samples are first, the prog bar estimate is too high at the beginning.
|
||||||
@@ -158,6 +156,11 @@ def run_generate():
|
|||||||
if intermediate_files:
|
if intermediate_files:
|
||||||
raise ValueError(f"Found files at {json_save_dir} please move or remove them.")
|
raise ValueError(f"Found files at {json_save_dir} please move or remove them.")
|
||||||
# In theory, a node could finish and save before another node hits this. If this happens, we can address later.
|
# In theory, a node could finish and save before another node hits this. If this happens, we can address later.
|
||||||
|
dataset_kwargs = {}
|
||||||
|
if args.src_lang is not None:
|
||||||
|
dataset_kwargs["src_lang"] = args.src_lang
|
||||||
|
if args.tgt_lang is not None:
|
||||||
|
dataset_kwargs["tgt_lang"] = args.tgt_lang
|
||||||
|
|
||||||
Path(args.save_dir).mkdir(exist_ok=True)
|
Path(args.save_dir).mkdir(exist_ok=True)
|
||||||
results, num_replicas = eval_data_dir(
|
results, num_replicas = eval_data_dir(
|
||||||
@@ -173,9 +176,7 @@ def run_generate():
|
|||||||
max_source_length=args.max_source_length,
|
max_source_length=args.max_source_length,
|
||||||
num_return_sequences=args.num_return_sequences,
|
num_return_sequences=args.num_return_sequences,
|
||||||
prefix=args.prefix,
|
prefix=args.prefix,
|
||||||
src_lang=args.src_lang,
|
dataset_kwargs=dataset_kwargs, **generate_kwargs,
|
||||||
tgt_lang=args.tgt_lang,
|
|
||||||
**generate_kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.local_rank <= 0:
|
if args.local_rank <= 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user