From 33d479d2b2bab7a20ec16b2aee64883156186d7e Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 14 Sep 2020 15:57:56 -0400 Subject: [PATCH] [s2s] distributed eval in one command (#7124) --- .../seq2seq/aggregate_distributed_results.py | 46 ------- examples/seq2seq/romanian_postprocessing.md | 2 +- examples/seq2seq/run_distributed_eval.py | 126 +++++++++++++++--- examples/seq2seq/utils.py | 36 ++--- 4 files changed, 125 insertions(+), 85 deletions(-) delete mode 100644 examples/seq2seq/aggregate_distributed_results.py diff --git a/examples/seq2seq/aggregate_distributed_results.py b/examples/seq2seq/aggregate_distributed_results.py deleted file mode 100644 index 5e6a8563a2..0000000000 --- a/examples/seq2seq/aggregate_distributed_results.py +++ /dev/null @@ -1,46 +0,0 @@ -from pathlib import Path - -import fire - - -try: - from .utils import calculate_bleu, calculate_rouge, load_json, save_json, write_txt_file -except ImportError: - from utils import calculate_bleu, calculate_rouge, load_json, save_json, write_txt_file - - -def combine_partial_results( - result_dir: str, save_dir: str = None, save_prefix=None, calc_bleu=False, just_metrics=False -): - """Write first n lines of each file f in src_dir to dest_dir/f """ - src_dir = Path(result_dir) - save_dir = Path(save_dir) - save_dir.mkdir(exist_ok=True) - paths_to_combine = list(src_dir.glob("rank*.json")) - records = [] - for partial_result in paths_to_combine: - records.extend(load_json(partial_result)) - preds = [x["pred"] for x in records] - labels = [x["label"] for x in records] - score_fn = calculate_bleu if calc_bleu else calculate_rouge - metrics = score_fn(preds, labels) - save_json(metrics, save_dir.joinpath("metrics.json")) # better would be be {prefix}_{rouge|bleu}.json - print(metrics) - if just_metrics: - return - - if save_prefix is None: - save_prefix = "generated" - print("using generated as prefix") - - tgt_path = save_dir.joinpath(f"{save_prefix}.target") - write_txt_file(labels, tgt_path) - pred_path = save_dir.joinpath(f"{save_prefix}.pred_target") - write_txt_file(preds, pred_path) - if "source" in records[0]: - src_path = save_dir.joinpath(f"{save_prefix}.source") - write_txt_file([x["source"] for x in records], src_path) - - -if __name__ == "__main__": - fire.Fire(combine_partial_results) diff --git a/examples/seq2seq/romanian_postprocessing.md b/examples/seq2seq/romanian_postprocessing.md index aa20829e8a..938f0d1d72 100644 --- a/examples/seq2seq/romanian_postprocessing.md +++ b/examples/seq2seq/romanian_postprocessing.md @@ -12,7 +12,7 @@ Note: You need to have your test_generations.txt before you start this process. cd $HOME git clone git@github.com:moses-smt/mosesdecoder.git cd mosesdecoder -git@github.com:rsennrich/wmt16-scripts.git +git clone git@github.com:rsennrich/wmt16-scripts.git ``` (2) define a function for post processing. diff --git a/examples/seq2seq/run_distributed_eval.py b/examples/seq2seq/run_distributed_eval.py index 2478edfdc6..4d7ba6e2c2 100644 --- a/examples/seq2seq/run_distributed_eval.py +++ b/examples/seq2seq/run_distributed_eval.py @@ -1,7 +1,10 @@ import argparse +import shutil +import time +from json import JSONDecodeError from logging import getLogger from pathlib import Path -from typing import Dict +from typing import Dict, List, Tuple import torch from torch.utils.data import DataLoader @@ -13,12 +16,29 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer logger = getLogger(__name__) try: - from .utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params + from .utils import ( + Seq2SeqDataset, + calculate_bleu, + calculate_rouge, + lmap, + load_json, + parse_numeric_cl_kwargs, + save_json, + use_task_specific_params, + write_txt_file, + ) except ImportError: - from utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params - - -DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + from utils import ( + Seq2SeqDataset, + calculate_bleu, + calculate_rouge, + lmap, + load_json, + parse_numeric_cl_kwargs, + save_json, + use_task_specific_params, + write_txt_file, + ) def eval_data_dir( @@ -30,7 +50,6 @@ def eval_data_dir( type_path="val", n_obs=None, fp16=False, - save_source=False, num_beams: int = 4, task="summarization", local_rank=None, @@ -62,7 +81,7 @@ def eval_data_dir( n_obs=n_obs, prefix=model.config.prefix, ) - sampler = ds.make_sortish_sampler(bs, distributed=True) + sampler = ds.make_sortish_sampler(bs, distributed=True, add_extra_examples=False) data_loader = DataLoader(ds, sampler=sampler, batch_size=bs, collate_fn=ds.collate_fn) dec_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=False) # tokenizer.decode results = [] @@ -75,23 +94,19 @@ def eval_data_dir( ) preds = tokenizer.batch_decode(summaries, **dec_kwargs) labels = tokenizer.batch_decode(batch["labels"], **dec_kwargs) - if save_source: - docs = tokenizer.batch_decode(batch["input_ids"], **dec_kwargs) + ids = batch["ids"] for i in range(len(labels)): label, pred = labels[i], preds[i] - if save_source: - results.append(dict(pred=pred, label=label, source=docs[i])) - else: - results.append(dict(pred=pred, label=label)) + results.append(dict(pred=pred, label=label, id=ids[i].item())) save_json(results, save_path) - return results + return results, sampler.num_replicas def run_generate(): parser = argparse.ArgumentParser( epilog="Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate" ) - parser.add_argument("--input_path", type=str, help="like cnn_dm/test.source") + parser.add_argument("--data_dir", type=str, help="like cnn_dm/test.source") parser.add_argument( "--model_name", type=str, @@ -113,17 +128,31 @@ def run_generate(): parser.add_argument( "--n_obs", type=int, default=None, required=False, help="How many observations. Defaults to all." ) + parser.add_argument( + "--sync_timeout", + type=int, + default=600, + required=False, + help="How long should master process wait for other processes to finish.", + ) parser.add_argument("--fp16", action="store_true") - parser.add_argument("--save_source", action="store_true") - + parser.add_argument("--debug", action="store_true") + start_time = time.time() args, rest = parser.parse_known_args() generate_kwargs = parse_numeric_cl_kwargs(rest) if generate_kwargs: print(f"parsed the following generate kwargs: {generate_kwargs}") + json_save_dir = Path(args.save_dir + "_tmp") + Path(json_save_dir).mkdir(exist_ok=True) # this handles locking. + intermediate_files = list(json_save_dir.glob("rank_*.json")) + if intermediate_files: + raise ValueError(f"Found files at {json_save_dir} please move or remove them.") + # In theory, a node could finish and save before another node hits this. If this happens, we can address later. + Path(args.save_dir).mkdir(exist_ok=True) - eval_data_dir( - args.input_path, - args.save_dir, + results, num_replicas = eval_data_dir( + args.data_dir, + json_save_dir, args.model_name, type_path=args.type_path, batch_size=args.bs, @@ -131,11 +160,64 @@ def run_generate(): task=args.task, local_rank=args.local_rank, n_obs=args.n_obs, - save_source=args.save_source, max_source_length=args.max_source_length, **generate_kwargs, ) + if args.local_rank <= 0: + save_dir = Path(args.save_dir) + save_dir.mkdir(exist_ok=True) + partial_results = gather_results_from_each_node(num_replicas, json_save_dir, args.sync_timeout) + preds, labels = combine_partial_results(partial_results) + # Calculate metrics, save metrics, and save _generations.txt + calc_bleu = "translation" in args.task + score_fn = calculate_bleu if calc_bleu else calculate_rouge + metric_name = "bleu" if calc_bleu else "rouge" + metrics: Dict = score_fn(preds, labels) + metrics["n_obs"] = len(preds) + runtime = time.time() - start_time + metrics["seconds_per_sample"] = round(runtime / metrics["n_obs"], 2) + # TODO(@stas00): add whatever metadata to metrics + metrics_save_path = save_dir.joinpath(f"{args.type_path}_{metric_name}.json") + save_json(metrics, metrics_save_path) + print(metrics) + write_txt_file(preds, save_dir.joinpath(f"{args.type_path}_generations.txt")) + if args.debug: + write_txt_file(labels, save_dir.joinpath(f"{args.type_path}.target")) + else: + shutil.rmtree(json_save_dir) + + +def combine_partial_results(partial_results) -> Tuple[List, List]: + """Concatenate partial results into one file, then sort it by id.""" + records = [] + for partial_result in partial_results: + records.extend(partial_result) + records = list(sorted(records, key=lambda x: x["id"])) + preds = [x["pred"] for x in records] + labels = [x["label"] for x in records] + return preds, labels + + +def gather_results_from_each_node(num_replicas, save_dir, timeout) -> List[Dict[str, List]]: + # WAIT FOR lots of .json files + start_wait = time.time() + logger.info("waiting for all nodes to finish") + json_data = None + while (time.time() - start_wait) < timeout: + json_files = list(save_dir.glob("rank_*.json")) + if len(json_files) < num_replicas: + continue + try: + # make sure all json files are fully saved + json_data = lmap(load_json, json_files) + return json_data + except JSONDecodeError: + continue + else: + raise TimeoutError("Rank 0 gave up on waiting for other processes") + # Unreachable + if __name__ == "__main__": # Usage for MT: diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py index 37f49b8073..c049b4372e 100644 --- a/examples/seq2seq/utils.py +++ b/examples/seq2seq/utils.py @@ -18,6 +18,7 @@ from torch import nn from torch.utils.data import Dataset, Sampler from transformers import BartTokenizer +from transformers.file_utils import cached_property def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100): @@ -114,9 +115,9 @@ class AbstractSeq2SeqDataset(Dataset): def get_char_lens(data_file): return [len(x) for x in Path(data_file).open().readlines()] - def make_sortish_sampler(self, batch_size, distributed=False): + def make_sortish_sampler(self, batch_size, distributed=False, **kwargs): if distributed: - return DistributedSortishSampler(self, batch_size) + return DistributedSortishSampler(self, batch_size, **kwargs) else: return SortishSampler(self.src_lens, batch_size) @@ -171,14 +172,11 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset): tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") assert source_line, f"empty source line for index {index}" assert tgt_line, f"empty tgt line for index {index}" - return { - "tgt_texts": tgt_line, - "src_texts": source_line, - } + return {"tgt_texts": tgt_line, "src_texts": source_line, "id": index - 1} def collate_fn(self, batch) -> Dict[str, torch.Tensor]: """Call prepare_seq2seq_batch.""" - batch_encoding = self.tokenizer.prepare_seq2seq_batch( + batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch( [x["src_texts"] for x in batch], src_lang=self.src_lang, tgt_texts=[x["tgt_texts"] for x in batch], @@ -187,8 +185,9 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset): max_target_length=self.max_target_length, return_tensors="pt", add_prefix_space=self.add_prefix_space, - ) - return batch_encoding.data + ).data + batch_encoding["ids"] = torch.tensor([x["id"] for x in batch]) + return batch_encoding class SortishSampler(Sampler): @@ -226,7 +225,7 @@ def sortish_sampler_indices(data: List, bs: int) -> np.array: class DistributedSortishSampler(Sampler): """Copied from torch DistributedSampler""" - def __init__(self, dataset, batch_size, num_replicas=None, rank=None): + def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True): if num_replicas is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") @@ -239,22 +238,27 @@ class DistributedSortishSampler(Sampler): self.num_replicas = num_replicas self.rank = rank self.epoch = 0 - self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) - self.total_size = self.num_samples * self.num_replicas + if add_extra_examples: + self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + else: + self.total_size = len(dataset) + self.num_samples = len(self.available_indices) self.batch_size = batch_size + self.add_extra_examples = add_extra_examples def __iter__(self) -> Iterable: g = torch.Generator() g.manual_seed(self.epoch) - available_indices = self.get_indices_for_rank() # indices[self.rank: self.total_size: self.num_replicas] - sortish_data = [self.dataset.src_lens[i] for i in available_indices] + sortish_data = [self.dataset.src_lens[i] for i in self.available_indices] sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size) - indices = [available_indices[i] for i in sortish_indices] + indices = [self.available_indices[i] for i in sortish_indices] assert len(indices) == self.num_samples return iter(indices) - def get_indices_for_rank(self) -> np.array: + @cached_property + def available_indices(self) -> np.array: indices = list(range(len(self.dataset))) # add extra samples to make it evenly divisible indices += indices[: (self.total_size - len(indices))]