[s2s] add src_lang kwarg for distributed eval (#7300)
This commit is contained in:
@@ -38,6 +38,9 @@ def eval_data_dir(
|
||||
fp16=False,
|
||||
task="summarization",
|
||||
local_rank=None,
|
||||
src_lang=None,
|
||||
tgt_lang=None,
|
||||
prefix="",
|
||||
**generate_kwargs,
|
||||
) -> Dict:
|
||||
"""Run evaluation on part of the data for one gpu and save to {save_dir}/rank_{rank}_output.json"""
|
||||
@@ -57,6 +60,8 @@ def eval_data_dir(
|
||||
use_task_specific_params(model, task) # update config with task specific params
|
||||
if max_source_length is None:
|
||||
max_source_length = tokenizer.model_max_length
|
||||
if prefix is None:
|
||||
prefix = prefix or getattr(model.config, "prefix", "") or ""
|
||||
ds = Seq2SeqDataset(
|
||||
tokenizer,
|
||||
data_dir,
|
||||
@@ -64,7 +69,9 @@ def eval_data_dir(
|
||||
max_target_length=1024,
|
||||
type_path=type_path,
|
||||
n_obs=n_obs,
|
||||
prefix=model.config.prefix,
|
||||
src_lang=src_lang,
|
||||
tgt_lang=tgt_lang,
|
||||
prefix=prefix,
|
||||
)
|
||||
# I set shuffle=True for a more accurate progress bar.
|
||||
# If all the longest samples are first, the prog bar estimate is too high at the beginning.
|
||||
@@ -118,6 +125,11 @@ def run_generate():
|
||||
required=False,
|
||||
help="How long should master process wait for other processes to finish.",
|
||||
)
|
||||
parser.add_argument("--src_lang", type=str, default=None, required=False)
|
||||
parser.add_argument("--tgt_lang", type=str, default=None, required=False)
|
||||
parser.add_argument(
|
||||
"--prefix", type=str, required=False, default=None, help="will be added to the begininng of src examples"
|
||||
)
|
||||
parser.add_argument("--fp16", action="store_true")
|
||||
parser.add_argument("--debug", action="store_true")
|
||||
start_time = time.time()
|
||||
@@ -144,6 +156,9 @@ def run_generate():
|
||||
local_rank=args.local_rank,
|
||||
n_obs=args.n_obs,
|
||||
max_source_length=args.max_source_length,
|
||||
prefix=args.prefix,
|
||||
src_lang=args.src_lang,
|
||||
tgt_lang=args.tgt_lang,
|
||||
**generate_kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -168,6 +168,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
||||
truncation: bool = True,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = "pt",
|
||||
add_prefix_space: bool = False, # ignored
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
"""Prepare a batch that can be passed directly to an instance of MBartModel.
|
||||
|
||||
Reference in New Issue
Block a user