[s2s] distributed eval cleanup (#7186)

This commit is contained in:
Sam Shleifer
2020-09-16 15:38:37 -04:00
committed by GitHub
parent 3babef815c
commit 0203ad43bc
4 changed files with 47 additions and 31 deletions

View File

@@ -4,7 +4,7 @@ import time
from json import JSONDecodeError
from logging import getLogger
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List
import torch
from torch.utils.data import DataLoader
@@ -22,7 +22,7 @@ try:
calculate_rouge,
lmap,
load_json,
parse_numeric_cl_kwargs,
parse_numeric_n_bool_cl_kwargs,
save_json,
use_task_specific_params,
write_txt_file,
@@ -34,7 +34,7 @@ except ImportError:
calculate_rouge,
lmap,
load_json,
parse_numeric_cl_kwargs,
parse_numeric_n_bool_cl_kwargs,
save_json,
use_task_specific_params,
write_txt_file,
@@ -50,7 +50,6 @@ def eval_data_dir(
type_path="val",
n_obs=None,
fp16=False,
num_beams: int = 4,
task="summarization",
local_rank=None,
**generate_kwargs,
@@ -81,23 +80,21 @@ def eval_data_dir(
n_obs=n_obs,
prefix=model.config.prefix,
)
sampler = ds.make_sortish_sampler(bs, distributed=True, add_extra_examples=False)
# shuffle=True is set for a more accurate progress bar.
# If all the longest samples come first, the progress bar's time estimate is too high at the beginning.
sampler = ds.make_sortish_sampler(bs, distributed=True, add_extra_examples=False, shuffle=True)
data_loader = DataLoader(ds, sampler=sampler, batch_size=bs, collate_fn=ds.collate_fn)
dec_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=False) # tokenizer.decode
results = []
for batch in tqdm(data_loader):
summaries = model.generate(
input_ids=batch["input_ids"].to(model.device),
attention_mask=batch["attention_mask"].to(model.device),
num_beams=num_beams,
**generate_kwargs,
)
preds = tokenizer.batch_decode(summaries, **dec_kwargs)
labels = tokenizer.batch_decode(batch["labels"], **dec_kwargs)
preds = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
ids = batch["ids"]
for i in range(len(labels)):
label, pred = labels[i], preds[i]
results.append(dict(pred=pred, label=label, id=ids[i].item()))
for i, pred in enumerate(preds):
results.append(dict(pred=pred, id=ids[i].item()))
save_json(results, save_path)
return results, sampler.num_replicas
@@ -139,8 +136,8 @@ def run_generate():
parser.add_argument("--debug", action="store_true")
start_time = time.time()
args, rest = parser.parse_known_args()
generate_kwargs = parse_numeric_cl_kwargs(rest)
if generate_kwargs:
generate_kwargs = parse_numeric_n_bool_cl_kwargs(rest)
if generate_kwargs and args.local_rank <= 0:
print(f"parsed the following generate kwargs: {generate_kwargs}")
json_save_dir = Path(args.save_dir + "_tmp")
Path(json_save_dir).mkdir(exist_ok=True) # this handles locking.
@@ -168,7 +165,10 @@ def run_generate():
save_dir = Path(args.save_dir)
save_dir.mkdir(exist_ok=True)
partial_results = gather_results_from_each_node(num_replicas, json_save_dir, args.sync_timeout)
preds, labels = combine_partial_results(partial_results)
preds = combine_partial_results(partial_results)
tgt_file = Path(args.data_dir).joinpath(args.type_path + ".target")
labels = [x.rstrip() for x in open(tgt_file).readlines()][: len(preds)]
# Calculate metrics, save metrics, and save _generations.txt
calc_bleu = "translation" in args.task
score_fn = calculate_bleu if calc_bleu else calculate_rouge
@@ -179,7 +179,7 @@ def run_generate():
metrics["seconds_per_sample"] = round(runtime / metrics["n_obs"], 2)
# TODO(@stas00): add whatever metadata to metrics
metrics_save_path = save_dir.joinpath(f"{args.type_path}_{metric_name}.json")
save_json(metrics, metrics_save_path)
save_json(metrics, metrics_save_path, indent=None)
print(metrics)
write_txt_file(preds, save_dir.joinpath(f"{args.type_path}_generations.txt"))
if args.debug:
@@ -188,15 +188,14 @@ def run_generate():
shutil.rmtree(json_save_dir)
def combine_partial_results(partial_results) -> Tuple[List, List]:
def combine_partial_results(partial_results) -> List:
"""Concatenate partial results into one list of records, then sort it by id."""
records = []
for partial_result in partial_results:
# each partial result is itself an iterable of records; flatten them into one list
records.extend(partial_result)
# order the merged records by their "id" field so the output order does not depend on the order of partial_results
records = list(sorted(records, key=lambda x: x["id"]))
preds = [x["pred"] for x in records]
labels = [x["label"] for x in records]
return preds, labels
return preds
def gather_results_from_each_node(num_replicas, save_dir, timeout) -> List[Dict[str, List]]: