[s2s] add src_lang kwarg for distributed eval (#7300)
This commit is contained in:
@@ -38,6 +38,9 @@ def eval_data_dir(
|
|||||||
fp16=False,
|
fp16=False,
|
||||||
task="summarization",
|
task="summarization",
|
||||||
local_rank=None,
|
local_rank=None,
|
||||||
|
src_lang=None,
|
||||||
|
tgt_lang=None,
|
||||||
|
prefix="",
|
||||||
**generate_kwargs,
|
**generate_kwargs,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""Run evaluation on part of the data for one gpu and save to {save_dir}/rank_{rank}_output.json"""
|
"""Run evaluation on part of the data for one gpu and save to {save_dir}/rank_{rank}_output.json"""
|
||||||
@@ -57,6 +60,8 @@ def eval_data_dir(
|
|||||||
use_task_specific_params(model, task) # update config with task specific params
|
use_task_specific_params(model, task) # update config with task specific params
|
||||||
if max_source_length is None:
|
if max_source_length is None:
|
||||||
max_source_length = tokenizer.model_max_length
|
max_source_length = tokenizer.model_max_length
|
||||||
|
if prefix is None:
|
||||||
|
prefix = prefix or getattr(model.config, "prefix", "") or ""
|
||||||
ds = Seq2SeqDataset(
|
ds = Seq2SeqDataset(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
data_dir,
|
data_dir,
|
||||||
@@ -64,7 +69,9 @@ def eval_data_dir(
|
|||||||
max_target_length=1024,
|
max_target_length=1024,
|
||||||
type_path=type_path,
|
type_path=type_path,
|
||||||
n_obs=n_obs,
|
n_obs=n_obs,
|
||||||
prefix=model.config.prefix,
|
src_lang=src_lang,
|
||||||
|
tgt_lang=tgt_lang,
|
||||||
|
prefix=prefix,
|
||||||
)
|
)
|
||||||
# I set shuffle=True for a more accurate progress bar.
|
# I set shuffle=True for a more accurate progress bar.
|
||||||
# If all the longest samples are first, the prog bar estimate is too high at the beginning.
|
# If all the longest samples are first, the prog bar estimate is too high at the beginning.
|
||||||
@@ -118,6 +125,11 @@ def run_generate():
|
|||||||
required=False,
|
required=False,
|
||||||
help="How long should master process wait for other processes to finish.",
|
help="How long should master process wait for other processes to finish.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--src_lang", type=str, default=None, required=False)
|
||||||
|
parser.add_argument("--tgt_lang", type=str, default=None, required=False)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prefix", type=str, required=False, default=None, help="will be added to the begininng of src examples"
|
||||||
|
)
|
||||||
parser.add_argument("--fp16", action="store_true")
|
parser.add_argument("--fp16", action="store_true")
|
||||||
parser.add_argument("--debug", action="store_true")
|
parser.add_argument("--debug", action="store_true")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -144,6 +156,9 @@ def run_generate():
|
|||||||
local_rank=args.local_rank,
|
local_rank=args.local_rank,
|
||||||
n_obs=args.n_obs,
|
n_obs=args.n_obs,
|
||||||
max_source_length=args.max_source_length,
|
max_source_length=args.max_source_length,
|
||||||
|
prefix=args.prefix,
|
||||||
|
src_lang=args.src_lang,
|
||||||
|
tgt_lang=args.tgt_lang,
|
||||||
**generate_kwargs,
|
**generate_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -168,6 +168,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
|||||||
truncation: bool = True,
|
truncation: bool = True,
|
||||||
padding: str = "longest",
|
padding: str = "longest",
|
||||||
return_tensors: str = "pt",
|
return_tensors: str = "pt",
|
||||||
|
add_prefix_space: bool = False, # ignored
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BatchEncoding:
|
) -> BatchEncoding:
|
||||||
"""Prepare a batch that can be passed directly to an instance of MBartModel.
|
"""Prepare a batch that can be passed directly to an instance of MBartModel.
|
||||||
|
|||||||
Reference in New Issue
Block a user