From e53138a1b9e5bf722c57484860546376d503f4e8 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 22 Sep 2020 18:26:37 -0400 Subject: [PATCH] [s2s] add src_lang kwarg for distributed eval (#7300) --- examples/seq2seq/run_distributed_eval.py | 17 ++++++++++++++++- src/transformers/tokenization_mbart.py | 1 + 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/examples/seq2seq/run_distributed_eval.py b/examples/seq2seq/run_distributed_eval.py index e8218e1917..fb66dd37ed 100644 --- a/examples/seq2seq/run_distributed_eval.py +++ b/examples/seq2seq/run_distributed_eval.py @@ -38,6 +38,9 @@ def eval_data_dir( fp16=False, task="summarization", local_rank=None, + src_lang=None, + tgt_lang=None, + prefix="", **generate_kwargs, ) -> Dict: """Run evaluation on part of the data for one gpu and save to {save_dir}/rank_{rank}_output.json""" @@ -57,6 +60,8 @@ def eval_data_dir( use_task_specific_params(model, task) # update config with task specific params if max_source_length is None: max_source_length = tokenizer.model_max_length + if prefix is None: + prefix = prefix or getattr(model.config, "prefix", "") or "" ds = Seq2SeqDataset( tokenizer, data_dir, @@ -64,7 +69,9 @@ def eval_data_dir( max_target_length=1024, type_path=type_path, n_obs=n_obs, - prefix=model.config.prefix, + src_lang=src_lang, + tgt_lang=tgt_lang, + prefix=prefix, ) # I set shuffle=True for a more accurate progress bar. # If all the longest samples are first, the prog bar estimate is too high at the beginning. @@ -118,6 +125,11 @@ def run_generate(): required=False, help="How long should master process wait for other processes to finish.", ) + parser.add_argument("--src_lang", type=str, default=None, required=False) + parser.add_argument("--tgt_lang", type=str, default=None, required=False) + parser.add_argument( + "--prefix", type=str, required=False, default=None, help="will be added to the begininng of src examples" + ) parser.add_argument("--fp16", action="store_true") parser.add_argument("--debug", action="store_true") start_time = time.time() @@ -144,6 +156,9 @@ def run_generate(): local_rank=args.local_rank, n_obs=args.n_obs, max_source_length=args.max_source_length, + prefix=args.prefix, + src_lang=args.src_lang, + tgt_lang=args.tgt_lang, **generate_kwargs, ) diff --git a/src/transformers/tokenization_mbart.py b/src/transformers/tokenization_mbart.py index eed9f0bf71..22a16f4a48 100644 --- a/src/transformers/tokenization_mbart.py +++ b/src/transformers/tokenization_mbart.py @@ -168,6 +168,7 @@ class MBartTokenizer(XLMRobertaTokenizer): truncation: bool = True, padding: str = "longest", return_tensors: str = "pt", + add_prefix_space: bool = False, # ignored **kwargs, ) -> BatchEncoding: """Prepare a batch that can be passed directly to an instance of MBartModel.