From e7f8d2ab6476782a088964ba2eb58c5d06db2f20 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Sun, 13 Sep 2020 17:28:18 -0400 Subject: [PATCH] [s2s] two stage run_distributed_eval.py (#7105) --- .../seq2seq/aggregate_distributed_results.py | 46 ++++++ examples/seq2seq/run_distributed_eval.py | 139 ++++++++++++++++++ examples/seq2seq/utils.py | 7 + 3 files changed, 192 insertions(+) create mode 100644 examples/seq2seq/aggregate_distributed_results.py create mode 100644 examples/seq2seq/run_distributed_eval.py diff --git a/examples/seq2seq/aggregate_distributed_results.py b/examples/seq2seq/aggregate_distributed_results.py new file mode 100644 index 0000000000..5e6a8563a2 --- /dev/null +++ b/examples/seq2seq/aggregate_distributed_results.py @@ -0,0 +1,46 @@ +from pathlib import Path + +import fire + + +try: + from .utils import calculate_bleu, calculate_rouge, load_json, save_json, write_txt_file +except ImportError: + from utils import calculate_bleu, calculate_rouge, load_json, save_json, write_txt_file + + +def combine_partial_results( + result_dir: str, save_dir: str = None, save_prefix=None, calc_bleu=False, just_metrics=False +): + """Write first n lines of each file f in src_dir to dest_dir/f """ + src_dir = Path(result_dir) + save_dir = Path(save_dir) + save_dir.mkdir(exist_ok=True) + paths_to_combine = list(src_dir.glob("rank*.json")) + records = [] + for partial_result in paths_to_combine: + records.extend(load_json(partial_result)) + preds = [x["pred"] for x in records] + labels = [x["label"] for x in records] + score_fn = calculate_bleu if calc_bleu else calculate_rouge + metrics = score_fn(preds, labels) + save_json(metrics, save_dir.joinpath("metrics.json")) # better would be be {prefix}_{rouge|bleu}.json + print(metrics) + if just_metrics: + return + + if save_prefix is None: + save_prefix = "generated" + print("using generated as prefix") + + tgt_path = save_dir.joinpath(f"{save_prefix}.target") + write_txt_file(labels, tgt_path) + pred_path = save_dir.joinpath(f"{save_prefix}.pred_target") + write_txt_file(preds, pred_path) + if "source" in records[0]: + src_path = save_dir.joinpath(f"{save_prefix}.source") + write_txt_file([x["source"] for x in records], src_path) + + +if __name__ == "__main__": + fire.Fire(combine_partial_results) diff --git a/examples/seq2seq/run_distributed_eval.py b/examples/seq2seq/run_distributed_eval.py new file mode 100644 index 0000000000..ea1b3ed524 --- /dev/null +++ b/examples/seq2seq/run_distributed_eval.py @@ -0,0 +1,139 @@ +import argparse +import warnings +from logging import getLogger +from pathlib import Path +from typing import Dict + +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +from transformers import AutoModelForSeq2SeqLM, AutoTokenizer + + +logger = getLogger(__name__) + +try: + from .utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params +except ImportError: + from utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params + +DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + + +def eval_data_dir( + data_dir, + save_dir: str, + model_name: str, + bs: int = 8, + max_source_length: int = 1024, + type_path="val", + n_obs=None, + fp16=False, + save_source=False, + num_beams: int = 4, + task="summarization", + local_rank=None, + **generate_kwargs, +) -> Dict: + """Run evaluation on part of the data for one gpu and save to {save_dir}/rank_{rank}_output.json""" + model_name = str(model_name) + assert local_rank is not None + torch.distributed.init_process_group(backend="nccl", rank=local_rank) + + save_dir = Path(save_dir) + save_path = save_dir.joinpath(f"rank_{local_rank}_output.json") + torch.cuda.set_device(local_rank) + model = AutoModelForSeq2SeqLM.from_pretrained(model_name).cuda() + if fp16: + model = model.half() + + tokenizer = AutoTokenizer.from_pretrained(model_name) + logger.info(f"Inferred tokenizer type: {tokenizer.__class__}") # if this is wrong, check config.model_type. + use_task_specific_params(model, task) # update config with task specific params + ds = Seq2SeqDataset( + tokenizer, + data_dir, + max_source_length, + max_target_length=1024, + type_path=type_path, + n_obs=n_obs, + prefix=model.config.prefix, + ) + sampler = ds.make_sortish_sampler(bs, distributed=True) + data_loader = DataLoader(ds, sampler=sampler, batch_size=bs, collate_fn=ds.collate_fn) + dec_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=False) # tokenizer.decode + results = [] + for batch in tqdm(data_loader): + summaries = model.generate( + input_ids=batch["input_ids"].to(model.device), + attention_mask=batch["attention_mask"].to(model.device), + num_beams=num_beams, + **generate_kwargs, + ) + preds = tokenizer.batch_decode(summaries, **dec_kwargs) + labels = tokenizer.batch_decode(batch["labels"], **dec_kwargs) + if save_source: + docs = tokenizer.batch_decode(batch["input_ids"], **dec_kwargs) + for i in range(len(labels)): + label, pred = labels[i], preds[i] + if save_source: + results.append(dict(pred=pred, label=label, source=docs[i])) + else: + results.append(dict(pred=pred, label=label)) + save_json(results, save_path) + return results + + +def run_generate(): + parser = argparse.ArgumentParser( + epilog="Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate" + ) + parser.add_argument("--input_path", type=str, help="like cnn_dm/test.source") + parser.add_argument( + "--model_name", + type=str, + help="like facebook/bart-large-cnn,t5-base, etc.", + default="sshleifer/distilbart-xsum-12-3", + ) + parser.add_argument("--save_dir", type=str, help="where to save", default="tmp_gen") + parser.add_argument("--prefix", type=str, default="test", help="which subset to evaluate typically train/val/test") + parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target") + parser.add_argument("--score_path", type=str, required=False, default="metrics.json", help="where to save metrics") + parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics") + parser.add_argument("--bs", type=int, default=8, required=False, help="batch size") + parser.add_argument( + "--local_rank", type=int, default=-1, required=False, help="should be passed by distributed.launch" + ) + + parser.add_argument( + "--n_obs", type=int, default=None, required=False, help="How many observations. Defaults to all." + ) + parser.add_argument("--fp16", action="store_true") + parser.add_argument("--save_source", action="store_true") + + args, rest = parser.parse_known_args() + parsed = parse_numeric_cl_kwargs(rest) + if parsed: + print(f"parsed the following generate kwargs: {parsed}") + Path(args.save_dir).mkdir(exist_ok=True) + if args.reference_path is None and Path(args.score_path).exists(): + warnings.warn(f"score_path {args.score_path} will be overwritten unless you type ctrl-c.") + eval_data_dir( + args.input_path, + args.save_dir, + args.model_name, + prefix=args.prefix, + batch_size=args.bs, + fp16=args.fp16, + task=args.task, + local_rank=args.local_rank, + n_obs=args.n_obs, + save_source=args.save_source, + **parsed, + ) + + +if __name__ == "__main__": + # Usage for MT: + run_generate() diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py index 25d64f7270..b732cde0bf 100644 --- a/examples/seq2seq/utils.py +++ b/examples/seq2seq/utils.py @@ -387,3 +387,10 @@ def parse_numeric_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, fl result[unparsed_args[i][2:]] = value return result + + +def write_txt_file(ordered_tgt, path): + f = Path(path).open("w") + for ln in ordered_tgt: + f.write(ln + "\n") + f.flush()