[s2s] distributed eval in one command (#7124)
This commit is contained in:
@@ -1,7 +1,10 @@
|
||||
import argparse
|
||||
import shutil
|
||||
import time
|
||||
from json import JSONDecodeError
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
@@ -13,12 +16,29 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
logger = getLogger(__name__)
|
||||
|
||||
try:
|
||||
from .utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params
|
||||
from .utils import (
|
||||
Seq2SeqDataset,
|
||||
calculate_bleu,
|
||||
calculate_rouge,
|
||||
lmap,
|
||||
load_json,
|
||||
parse_numeric_cl_kwargs,
|
||||
save_json,
|
||||
use_task_specific_params,
|
||||
write_txt_file,
|
||||
)
|
||||
except ImportError:
|
||||
from utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params
|
||||
|
||||
|
||||
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
from utils import (
|
||||
Seq2SeqDataset,
|
||||
calculate_bleu,
|
||||
calculate_rouge,
|
||||
lmap,
|
||||
load_json,
|
||||
parse_numeric_cl_kwargs,
|
||||
save_json,
|
||||
use_task_specific_params,
|
||||
write_txt_file,
|
||||
)
|
||||
|
||||
|
||||
def eval_data_dir(
|
||||
@@ -30,7 +50,6 @@ def eval_data_dir(
|
||||
type_path="val",
|
||||
n_obs=None,
|
||||
fp16=False,
|
||||
save_source=False,
|
||||
num_beams: int = 4,
|
||||
task="summarization",
|
||||
local_rank=None,
|
||||
@@ -62,7 +81,7 @@ def eval_data_dir(
|
||||
n_obs=n_obs,
|
||||
prefix=model.config.prefix,
|
||||
)
|
||||
sampler = ds.make_sortish_sampler(bs, distributed=True)
|
||||
sampler = ds.make_sortish_sampler(bs, distributed=True, add_extra_examples=False)
|
||||
data_loader = DataLoader(ds, sampler=sampler, batch_size=bs, collate_fn=ds.collate_fn)
|
||||
dec_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=False) # tokenizer.decode
|
||||
results = []
|
||||
@@ -75,23 +94,19 @@ def eval_data_dir(
|
||||
)
|
||||
preds = tokenizer.batch_decode(summaries, **dec_kwargs)
|
||||
labels = tokenizer.batch_decode(batch["labels"], **dec_kwargs)
|
||||
if save_source:
|
||||
docs = tokenizer.batch_decode(batch["input_ids"], **dec_kwargs)
|
||||
ids = batch["ids"]
|
||||
for i in range(len(labels)):
|
||||
label, pred = labels[i], preds[i]
|
||||
if save_source:
|
||||
results.append(dict(pred=pred, label=label, source=docs[i]))
|
||||
else:
|
||||
results.append(dict(pred=pred, label=label))
|
||||
results.append(dict(pred=pred, label=label, id=ids[i].item()))
|
||||
save_json(results, save_path)
|
||||
return results
|
||||
return results, sampler.num_replicas
|
||||
|
||||
|
||||
def run_generate():
|
||||
parser = argparse.ArgumentParser(
|
||||
epilog="Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate"
|
||||
)
|
||||
parser.add_argument("--input_path", type=str, help="like cnn_dm/test.source")
|
||||
parser.add_argument("--data_dir", type=str, help="like cnn_dm/test.source")
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
@@ -113,17 +128,31 @@ def run_generate():
|
||||
parser.add_argument(
|
||||
"--n_obs", type=int, default=None, required=False, help="How many observations. Defaults to all."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sync_timeout",
|
||||
type=int,
|
||||
default=600,
|
||||
required=False,
|
||||
help="How long should master process wait for other processes to finish.",
|
||||
)
|
||||
parser.add_argument("--fp16", action="store_true")
|
||||
parser.add_argument("--save_source", action="store_true")
|
||||
|
||||
parser.add_argument("--debug", action="store_true")
|
||||
start_time = time.time()
|
||||
args, rest = parser.parse_known_args()
|
||||
generate_kwargs = parse_numeric_cl_kwargs(rest)
|
||||
if generate_kwargs:
|
||||
print(f"parsed the following generate kwargs: {generate_kwargs}")
|
||||
json_save_dir = Path(args.save_dir + "_tmp")
|
||||
Path(json_save_dir).mkdir(exist_ok=True) # this handles locking.
|
||||
intermediate_files = list(json_save_dir.glob("rank_*.json"))
|
||||
if intermediate_files:
|
||||
raise ValueError(f"Found files at {json_save_dir} please move or remove them.")
|
||||
# In theory, a node could finish and save before another node hits this. If this happens, we can address later.
|
||||
|
||||
Path(args.save_dir).mkdir(exist_ok=True)
|
||||
eval_data_dir(
|
||||
args.input_path,
|
||||
args.save_dir,
|
||||
results, num_replicas = eval_data_dir(
|
||||
args.data_dir,
|
||||
json_save_dir,
|
||||
args.model_name,
|
||||
type_path=args.type_path,
|
||||
batch_size=args.bs,
|
||||
@@ -131,11 +160,64 @@ def run_generate():
|
||||
task=args.task,
|
||||
local_rank=args.local_rank,
|
||||
n_obs=args.n_obs,
|
||||
save_source=args.save_source,
|
||||
max_source_length=args.max_source_length,
|
||||
**generate_kwargs,
|
||||
)
|
||||
|
||||
if args.local_rank <= 0:
|
||||
save_dir = Path(args.save_dir)
|
||||
save_dir.mkdir(exist_ok=True)
|
||||
partial_results = gather_results_from_each_node(num_replicas, json_save_dir, args.sync_timeout)
|
||||
preds, labels = combine_partial_results(partial_results)
|
||||
# Calculate metrics, save metrics, and save _generations.txt
|
||||
calc_bleu = "translation" in args.task
|
||||
score_fn = calculate_bleu if calc_bleu else calculate_rouge
|
||||
metric_name = "bleu" if calc_bleu else "rouge"
|
||||
metrics: Dict = score_fn(preds, labels)
|
||||
metrics["n_obs"] = len(preds)
|
||||
runtime = time.time() - start_time
|
||||
metrics["seconds_per_sample"] = round(runtime / metrics["n_obs"], 2)
|
||||
# TODO(@stas00): add whatever metadata to metrics
|
||||
metrics_save_path = save_dir.joinpath(f"{args.type_path}_{metric_name}.json")
|
||||
save_json(metrics, metrics_save_path)
|
||||
print(metrics)
|
||||
write_txt_file(preds, save_dir.joinpath(f"{args.type_path}_generations.txt"))
|
||||
if args.debug:
|
||||
write_txt_file(labels, save_dir.joinpath(f"{args.type_path}.target"))
|
||||
else:
|
||||
shutil.rmtree(json_save_dir)
|
||||
|
||||
|
||||
def combine_partial_results(partial_results) -> Tuple[List, List]:
|
||||
"""Concatenate partial results into one file, then sort it by id."""
|
||||
records = []
|
||||
for partial_result in partial_results:
|
||||
records.extend(partial_result)
|
||||
records = list(sorted(records, key=lambda x: x["id"]))
|
||||
preds = [x["pred"] for x in records]
|
||||
labels = [x["label"] for x in records]
|
||||
return preds, labels
|
||||
|
||||
|
||||
def gather_results_from_each_node(num_replicas, save_dir, timeout) -> List[Dict[str, List]]:
|
||||
# WAIT FOR lots of .json files
|
||||
start_wait = time.time()
|
||||
logger.info("waiting for all nodes to finish")
|
||||
json_data = None
|
||||
while (time.time() - start_wait) < timeout:
|
||||
json_files = list(save_dir.glob("rank_*.json"))
|
||||
if len(json_files) < num_replicas:
|
||||
continue
|
||||
try:
|
||||
# make sure all json files are fully saved
|
||||
json_data = lmap(load_json, json_files)
|
||||
return json_data
|
||||
except JSONDecodeError:
|
||||
continue
|
||||
else:
|
||||
raise TimeoutError("Rank 0 gave up on waiting for other processes")
|
||||
# Unreachable
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Usage for MT:
|
||||
|
||||
Reference in New Issue
Block a user