From f1a4e06f1fe2baaf85799db2b0316991ee1a2405 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 20 Jul 2020 15:18:26 -0400 Subject: [PATCH] [Fix] seq2seq pack_dataset.py actually packs (#5913) Huge MT speedup! --- examples/seq2seq/pack_dataset.py | 33 +++++++++++++++++------ examples/seq2seq/test_seq2seq_examples.py | 8 ++++++ 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/examples/seq2seq/pack_dataset.py b/examples/seq2seq/pack_dataset.py index fc2e0d78b4..599d133a62 100644 --- a/examples/seq2seq/pack_dataset.py +++ b/examples/seq2seq/pack_dataset.py @@ -16,13 +16,14 @@ from transformers import AutoTokenizer def pack_examples(tok, src_examples, tgt_examples, max_tokens=1024): finished_src, finished_tgt = [], [] - new_src, new_tgt = "", "" + sorted_examples = list(sorted(zip(src_examples, tgt_examples), key=lambda x: len(x[0]))) + new_src, new_tgt = sorted_examples[0] def is_too_big(strang): return tok(strang, return_tensors="pt").input_ids.shape[1] > max_tokens - for src, tgt in tqdm(sorted_examples): + for src, tgt in tqdm(sorted_examples[1:]): cand_src = new_src + " " + src cand_tgt = new_tgt + " " + tgt if is_too_big(cand_src) or is_too_big(cand_tgt): # cant fit, finalize example @@ -31,21 +32,37 @@ def pack_examples(tok, src_examples, tgt_examples, max_tokens=1024): new_src, new_tgt = src, tgt else: # can fit, keep adding new_src, new_tgt = cand_src, cand_tgt + # import ipdb; ipdb.set_trace() + # cleanup + if new_src: + assert new_tgt + finished_src.append(new_src) + finished_tgt.append(new_tgt) return finished_src, finished_tgt +def minify(src_dir: Path, dest_dir: Path, n: int): + """Write first n lines of each file f in src_dir to dest_dir/f""" + dest_dir.mkdir(exist_ok=True) + for path in src_dir.iterdir(): + new = [x.rstrip() for x in list(path.open().readlines())][:n] + dest_path = dest_dir.joinpath(path.name) + print(dest_path) + dest_path.open("w").write("\n".join(new)) + + def pack_data_dir(tok, data_dir: Path, max_tokens, save_path): save_path = Path(save_path) save_path.mkdir(exist_ok=True) for split in ["val", "test", "train"]: src_path, tgt_path = data_dir / f"{split}.source", data_dir / f"{split}.target" - src_docs = list(Path(src_path).open().readlines()) - tgt_docs = list(Path(tgt_path).open().readlines()) - src, tgt = pack_examples(tok, src_docs, tgt_docs, max_tokens) - print(f"packed {split} split from {len(src_docs)} examples -> {len(src)}.") - Path(save_path / f"{split}.source").open("w").write("\n".join(src)) - Path(save_path / f"{split}.target").open("w").write("\n".join(tgt)) + src_docs = [x.rstrip() for x in Path(src_path).open().readlines()] + tgt_docs = [x.rstrip() for x in Path(tgt_path).open().readlines()] + packed_src, packed_tgt = pack_examples(tok, src_docs, tgt_docs, max_tokens) + print(f"packed {split} split from {len(src_docs)} examples -> {len(packed_src)}.") + Path(save_path / f"{split}.source").open("w").write("\n".join(packed_src)) + Path(save_path / f"{split}.target").open("w").write("\n".join(packed_tgt)) def packer_cli(): diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py index abf9b908a6..6e9117228a 100644 --- a/examples/seq2seq/test_seq2seq_examples.py +++ b/examples/seq2seq/test_seq2seq_examples.py @@ -254,11 +254,19 @@ def test_finetune(model): def test_pack_dataset(): tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25") + tmp_dir = Path(make_test_data_dir()) + orig_examples = tmp_dir.joinpath("train.source").open().readlines() save_dir = Path(tempfile.mkdtemp(prefix="packed_")) pack_data_dir(tokenizer, tmp_dir, 128, save_dir) orig_paths = {x.name for x in tmp_dir.iterdir()} new_paths = {x.name for x in save_dir.iterdir()} + packed_examples = save_dir.joinpath("train.source").open().readlines() + # orig: [' Sam ate lunch today.\n', 'Sams lunch ingredients.'] + # desired_packed: [' Sam ate lunch today.\n Sams lunch ingredients.'] + assert len(packed_examples) < len(orig_examples) + assert len(packed_examples) == 1 + assert len(packed_examples[0]) == sum(len(x) for x in orig_examples) assert orig_paths == new_paths