[s2s] distributed eval allows num_return_sequences > 1 (#7254)
This commit is contained in:
@@ -17,6 +17,7 @@ from utils import (
|
||||
Seq2SeqDataset,
|
||||
calculate_bleu,
|
||||
calculate_rouge,
|
||||
chunks,
|
||||
lmap,
|
||||
load_json,
|
||||
parse_numeric_n_bool_cl_kwargs,
|
||||
@@ -40,6 +41,7 @@ def eval_data_dir(
|
||||
fp16=False,
|
||||
task="summarization",
|
||||
local_rank=None,
|
||||
num_return_sequences=1,
|
||||
src_lang=None,
|
||||
tgt_lang=None,
|
||||
prefix="",
|
||||
@@ -56,10 +58,15 @@ def eval_data_dir(
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).cuda()
|
||||
if fp16:
|
||||
model = model.half()
|
||||
# determine if we need to increase num_beams
|
||||
use_task_specific_params(model, task) # update config with task specific params
|
||||
num_beams = generate_kwargs.pop("num_beams", model.config.num_beams) # AttributeError risk?
|
||||
if num_return_sequences > num_beams:
|
||||
num_beams = num_return_sequences
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
logger.info(f"Inferred tokenizer type: {tokenizer.__class__}") # if this is wrong, check config.model_type.
|
||||
use_task_specific_params(model, task) # update config with task specific params
|
||||
|
||||
if max_source_length is None:
|
||||
max_source_length = tokenizer.model_max_length
|
||||
if prefix is None:
|
||||
@@ -84,10 +91,14 @@ def eval_data_dir(
|
||||
summaries = model.generate(
|
||||
input_ids=batch["input_ids"].to(model.device),
|
||||
attention_mask=batch["attention_mask"].to(model.device),
|
||||
num_return_sequences=num_return_sequences,
|
||||
num_beams=num_beams,
|
||||
**generate_kwargs,
|
||||
)
|
||||
preds = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
ids = batch["ids"]
|
||||
if num_return_sequences > 1:
|
||||
preds = chunks(preds, num_return_sequences) # batch size chunks, each of size num_return_seq
|
||||
for i, pred in enumerate(preds):
|
||||
results.append(dict(pred=pred, id=ids[i].item()))
|
||||
save_json(results, save_path)
|
||||
@@ -110,7 +121,6 @@ def run_generate():
|
||||
parser.add_argument(
|
||||
"--type_path", type=str, default="test", help="which subset to evaluate typically train/val/test"
|
||||
)
|
||||
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target")
|
||||
parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics")
|
||||
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
|
||||
parser.add_argument(
|
||||
@@ -120,6 +130,9 @@ def run_generate():
|
||||
parser.add_argument(
|
||||
"--n_obs", type=int, default=None, required=False, help="How many observations. Defaults to all."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_return_sequences", type=int, default=1, required=False, help="How many sequences to return"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sync_timeout",
|
||||
type=int,
|
||||
@@ -158,6 +171,7 @@ def run_generate():
|
||||
local_rank=args.local_rank,
|
||||
n_obs=args.n_obs,
|
||||
max_source_length=args.max_source_length,
|
||||
num_return_sequences=args.num_return_sequences,
|
||||
prefix=args.prefix,
|
||||
src_lang=args.src_lang,
|
||||
tgt_lang=args.tgt_lang,
|
||||
@@ -169,6 +183,11 @@ def run_generate():
|
||||
save_dir.mkdir(exist_ok=True)
|
||||
partial_results = gather_results_from_each_node(num_replicas, json_save_dir, args.sync_timeout)
|
||||
preds = combine_partial_results(partial_results)
|
||||
if args.num_return_sequences > 1:
|
||||
save_path = save_dir.joinpath("pseudolabel_results.json")
|
||||
print(f"Saving aggregated results at {save_path}, intermediate in {json_save_dir}/")
|
||||
save_json(preds, save_path)
|
||||
return
|
||||
tgt_file = Path(args.data_dir).joinpath(args.type_path + ".target")
|
||||
labels = [x.rstrip() for x in open(tgt_file).readlines()][: len(preds)]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user