diff --git a/examples/seq2seq/pack_dataset.py b/examples/seq2seq/pack_dataset.py new file mode 100644 index 0000000000..fc2e0d78b4 --- /dev/null +++ b/examples/seq2seq/pack_dataset.py @@ -0,0 +1,63 @@ +"""Fill examples with bitext up to max_tokens without breaking up examples. +[['I went', 'yo fui'], +['to the store', 'a la tienda'] +] +=> ['I went to the store', 'yo fui a la tienda'] +""" + +import argparse +from pathlib import Path + +from tqdm import tqdm + +from transformers import AutoTokenizer + + +def pack_examples(tok, src_examples, tgt_examples, max_tokens=1024): + + finished_src, finished_tgt = [], [] + new_src, new_tgt = "", "" + sorted_examples = list(sorted(zip(src_examples, tgt_examples), key=lambda x: len(x[0]))) + + def is_too_big(strang): + return tok(strang, return_tensors="pt").input_ids.shape[1] > max_tokens + + for src, tgt in tqdm(sorted_examples): + cand_src = new_src + " " + src + cand_tgt = new_tgt + " " + tgt + if is_too_big(cand_src) or is_too_big(cand_tgt): # cant fit, finalize example + finished_src.append(new_src) + finished_tgt.append(new_tgt) + new_src, new_tgt = src, tgt + else: # can fit, keep adding + new_src, new_tgt = cand_src, cand_tgt + + return finished_src, finished_tgt + + +def pack_data_dir(tok, data_dir: Path, max_tokens, save_path): + save_path = Path(save_path) + save_path.mkdir(exist_ok=True) + for split in ["val", "test", "train"]: + src_path, tgt_path = data_dir / f"{split}.source", data_dir / f"{split}.target" + src_docs = list(Path(src_path).open().readlines()) + tgt_docs = list(Path(tgt_path).open().readlines()) + src, tgt = pack_examples(tok, src_docs, tgt_docs, max_tokens) + print(f"packed {split} split from {len(src_docs)} examples -> {len(src)}.") + Path(save_path / f"{split}.source").open("w").write("\n".join(src)) + Path(save_path / f"{split}.target").open("w").write("\n".join(tgt)) + + +def packer_cli(): + parser = argparse.ArgumentParser() + parser.add_argument("--tok_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.") + parser.add_argument("--max_seq_len", type=int, default=128) + parser.add_argument("--data_dir", type=str) + parser.add_argument("--save_path", type=str) + args = parser.parse_args() + tokenizer = AutoTokenizer.from_pretrained(args.tok_name) + return pack_data_dir(tokenizer, Path(args.data_dir), args.max_seq_len, args.save_path) + + +if __name__ == "__main__": + packer_cli() diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py index 3989046e65..0a6c59ab38 100644 --- a/examples/seq2seq/test_seq2seq_examples.py +++ b/examples/seq2seq/test_seq2seq_examples.py @@ -16,6 +16,7 @@ from transformers.testing_utils import require_multigpu from .distillation import distill_main, evaluate_checkpoint from .finetune import main +from .pack_dataset import pack_data_dir from .run_eval import generate_summaries_or_translations, run_generate from .utils import SummarizationDataset, lmap, load_json @@ -249,6 +250,16 @@ def test_finetune(model): assert bart.decoder.embed_tokens == bart.shared +def test_pack_dataset(): + tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25") + tmp_dir = Path(make_test_data_dir()) + save_dir = Path(tempfile.mkdtemp(prefix="packed_")) + pack_data_dir(tokenizer, tmp_dir, 128, save_dir) + orig_paths = {x.name for x in tmp_dir.iterdir()} + new_paths = {x.name for x in save_dir.iterdir()} + assert orig_paths == new_paths + + @pytest.mark.parametrize( ["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)] )