From 0203ad43bcd0b29423dec6ca1a58ed58300f0d61 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Wed, 16 Sep 2020 15:38:37 -0400 Subject: [PATCH] [s2s] distributed eval cleanup (#7186) --- examples/seq2seq/README.md | 14 +++++++++ examples/seq2seq/run_distributed_eval.py | 37 ++++++++++++------------ examples/seq2seq/run_eval.py | 2 +- examples/seq2seq/utils.py | 25 +++++++++------- 4 files changed, 47 insertions(+), 31 deletions(-) diff --git a/examples/seq2seq/README.md b/examples/seq2seq/README.md index 7223f87c3e..ac71ba67e0 100644 --- a/examples/seq2seq/README.md +++ b/examples/seq2seq/README.md @@ -227,6 +227,20 @@ python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_ --fp16 \ --bs 32 ``` +### Multi-GPU Evalulation +here is a command to run xsum evaluation on 8 GPUS. It is more than linearly faster than run_eval.py in some cases +because it uses SortishSampler to minimize padding. You can also use it on 1 GPU. `data_dir` must have +`{type_path}.source` and `{type_path}.target`. Run `python run_distributed_eval.py --help` for all clargs. + +```bash +python -m torch.distributed.launch --nproc_per_node=8 run_distributed_eval.py \ + --model_name sshleifer/distilbart-large-xsum-12-3 \ + --save_dir xsum_generations \ + --data_dir xsum \ + --fp16 # you can pass generate kwargs like num_beams here, just like run_eval.py +``` + +Contributions that implement this command for other distributed hardware setups are welcome! #### run_eval tips and tricks diff --git a/examples/seq2seq/run_distributed_eval.py b/examples/seq2seq/run_distributed_eval.py index 4d7ba6e2c2..da7a6b7fa1 100644 --- a/examples/seq2seq/run_distributed_eval.py +++ b/examples/seq2seq/run_distributed_eval.py @@ -4,7 +4,7 @@ import time from json import JSONDecodeError from logging import getLogger from pathlib import Path -from typing import Dict, List, Tuple +from typing import Dict, List import torch from torch.utils.data import DataLoader @@ -22,7 +22,7 @@ try: calculate_rouge, lmap, load_json, - parse_numeric_cl_kwargs, + parse_numeric_n_bool_cl_kwargs, save_json, use_task_specific_params, write_txt_file, @@ -34,7 +34,7 @@ except ImportError: calculate_rouge, lmap, load_json, - parse_numeric_cl_kwargs, + parse_numeric_n_bool_cl_kwargs, save_json, use_task_specific_params, write_txt_file, @@ -50,7 +50,6 @@ def eval_data_dir( type_path="val", n_obs=None, fp16=False, - num_beams: int = 4, task="summarization", local_rank=None, **generate_kwargs, @@ -81,23 +80,21 @@ def eval_data_dir( n_obs=n_obs, prefix=model.config.prefix, ) - sampler = ds.make_sortish_sampler(bs, distributed=True, add_extra_examples=False) + # I set shuffle=True for a more accurate progress bar. + # If all the longest samples are first, the prog bar estimate is too high at the beginning. + sampler = ds.make_sortish_sampler(bs, distributed=True, add_extra_examples=False, shuffle=True) data_loader = DataLoader(ds, sampler=sampler, batch_size=bs, collate_fn=ds.collate_fn) - dec_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=False) # tokenizer.decode results = [] for batch in tqdm(data_loader): summaries = model.generate( input_ids=batch["input_ids"].to(model.device), attention_mask=batch["attention_mask"].to(model.device), - num_beams=num_beams, **generate_kwargs, ) - preds = tokenizer.batch_decode(summaries, **dec_kwargs) - labels = tokenizer.batch_decode(batch["labels"], **dec_kwargs) + preds = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False) ids = batch["ids"] - for i in range(len(labels)): - label, pred = labels[i], preds[i] - results.append(dict(pred=pred, label=label, id=ids[i].item())) + for i, pred in enumerate(preds): + results.append(dict(pred=pred, id=ids[i].item())) save_json(results, save_path) return results, sampler.num_replicas @@ -139,8 +136,8 @@ def run_generate(): parser.add_argument("--debug", action="store_true") start_time = time.time() args, rest = parser.parse_known_args() - generate_kwargs = parse_numeric_cl_kwargs(rest) - if generate_kwargs: + generate_kwargs = parse_numeric_n_bool_cl_kwargs(rest) + if generate_kwargs and args.local_rank <= 0: print(f"parsed the following generate kwargs: {generate_kwargs}") json_save_dir = Path(args.save_dir + "_tmp") Path(json_save_dir).mkdir(exist_ok=True) # this handles locking. @@ -168,7 +165,10 @@ def run_generate(): save_dir = Path(args.save_dir) save_dir.mkdir(exist_ok=True) partial_results = gather_results_from_each_node(num_replicas, json_save_dir, args.sync_timeout) - preds, labels = combine_partial_results(partial_results) + preds = combine_partial_results(partial_results) + tgt_file = Path(args.data_dir).joinpath(args.type_path + ".target") + labels = [x.rstrip() for x in open(tgt_file).readlines()][: len(preds)] + # Calculate metrics, save metrics, and save _generations.txt calc_bleu = "translation" in args.task score_fn = calculate_bleu if calc_bleu else calculate_rouge @@ -179,7 +179,7 @@ def run_generate(): metrics["seconds_per_sample"] = round(runtime / metrics["n_obs"], 2) # TODO(@stas00): add whatever metadata to metrics metrics_save_path = save_dir.joinpath(f"{args.type_path}_{metric_name}.json") - save_json(metrics, metrics_save_path) + save_json(metrics, metrics_save_path, indent=None) print(metrics) write_txt_file(preds, save_dir.joinpath(f"{args.type_path}_generations.txt")) if args.debug: @@ -188,15 +188,14 @@ def run_generate(): shutil.rmtree(json_save_dir) -def combine_partial_results(partial_results) -> Tuple[List, List]: +def combine_partial_results(partial_results) -> List: """Concatenate partial results into one file, then sort it by id.""" records = [] for partial_result in partial_results: records.extend(partial_result) records = list(sorted(records, key=lambda x: x["id"])) preds = [x["pred"] for x in records] - labels = [x["label"] for x in records] - return preds, labels + return preds def gather_results_from_each_node(num_replicas, save_dir, timeout) -> List[Dict[str, List]]: diff --git a/examples/seq2seq/run_eval.py b/examples/seq2seq/run_eval.py index 23909a3d07..4b2a551e7a 100644 --- a/examples/seq2seq/run_eval.py +++ b/examples/seq2seq/run_eval.py @@ -156,7 +156,7 @@ def run_generate(verbose=True): scores["info"] = args.info if verbose: - print(*scores) + print(scores) if args.score_path is not None: path = args.score_path diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py index b1d35cf10e..dd3348c3ba 100644 --- a/examples/seq2seq/utils.py +++ b/examples/seq2seq/utils.py @@ -115,11 +115,11 @@ class AbstractSeq2SeqDataset(Dataset): def get_char_lens(data_file): return [len(x) for x in Path(data_file).open().readlines()] - def make_sortish_sampler(self, batch_size, distributed=False, **kwargs): + def make_sortish_sampler(self, batch_size, distributed=False, shuffle=True, **kwargs): if distributed: - return DistributedSortishSampler(self, batch_size, **kwargs) + return DistributedSortishSampler(self, batch_size, shuffle=shuffle, **kwargs) else: - return SortishSampler(self.src_lens, batch_size) + return SortishSampler(self.src_lens, batch_size, shuffle=shuffle) def __getitem__(self, item): raise NotImplementedError("You must implement this") @@ -193,18 +193,20 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset): class SortishSampler(Sampler): "Go through the text data by order of src length with a bit of randomness. From fastai repo." - def __init__(self, data, batch_size): - self.data, self.bs = data, batch_size + def __init__(self, data, batch_size, shuffle=True): + self.data, self.bs, self.shuffle = data, batch_size, shuffle def __len__(self) -> int: return len(self.data) def __iter__(self): - return iter(sortish_sampler_indices(self.data, self.bs)) + return iter(sortish_sampler_indices(self.data, self.bs, shuffle=self.shuffle)) -def sortish_sampler_indices(data: List, bs: int) -> np.array: +def sortish_sampler_indices(data: List, bs: int, shuffle=True) -> np.array: "Go through the text data by order of src length with a bit of randomness. From fastai repo." + if not shuffle: + return np.argsort(np.array(data) * -1) def key_fn(i): return data[i] @@ -225,7 +227,7 @@ def sortish_sampler_indices(data: List, bs: int) -> np.array: class DistributedSortishSampler(Sampler): """Copied from torch DistributedSampler""" - def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True): + def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True, shuffle=True): if num_replicas is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") @@ -246,13 +248,14 @@ class DistributedSortishSampler(Sampler): self.num_samples = len(self.available_indices) self.batch_size = batch_size self.add_extra_examples = add_extra_examples + self.shuffle = shuffle def __iter__(self) -> Iterable: g = torch.Generator() g.manual_seed(self.epoch) sortish_data = [self.dataset.src_lens[i] for i in self.available_indices] - sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size) + sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size, shuffle=self.shuffle) indices = [self.available_indices[i] for i in sortish_indices] assert len(indices) == self.num_samples return iter(indices) @@ -309,9 +312,9 @@ def save_git_info(folder_path: str) -> None: save_json(repo_infos, os.path.join(folder_path, "git_log.json")) -def save_json(content, path): +def save_json(content, path, indent=4, **json_dump_kwargs): with open(path, "w") as f: - json.dump(content, f, indent=4) + json.dump(content, f, indent=indent, **json_dump_kwargs) def load_json(path):