From f6cb0f806efecb64df40c946dacaad0adad33d53 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Tue, 11 Aug 2020 09:04:17 -0700 Subject: [PATCH] [s2s] wmt download script use less ram (#6405) --- examples/seq2seq/download_wmt.py | 38 ++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/examples/seq2seq/download_wmt.py b/examples/seq2seq/download_wmt.py index 7994ab44c4..294a489a84 100644 --- a/examples/seq2seq/download_wmt.py +++ b/examples/seq2seq/download_wmt.py @@ -4,44 +4,48 @@ import fire from tqdm import tqdm -def download_wmt_dataset(src_lang, tgt_lang, dataset="wmt19", save_dir=None) -> None: +def download_wmt_dataset(src_lang="ro", tgt_lang="en", dataset="wmt16", save_dir=None) -> None: """Download a dataset using the nlp package and save it to the format expected by finetune.py Format of save_dir: train.source, train.target, val.source, val.target, test.source, test.target. Args: src_lang: source language tgt_lang: target language - dataset: like wmt19 (if you don't know, try wmt19). + dataset: wmt16, wmt17, etc. wmt16 is a good start as it's small. To get the full list run `import nlp; print([d.id for d in nlp.list_datasets() if "wmt" in d.id])` save_dir: , where to save the datasets, defaults to f'{dataset}-{src_lang}-{tgt_lang}' Usage: - >>> download_wmt_dataset('en', 'ru', dataset='wmt19') # saves to wmt19_en_ru + >>> download_wmt_dataset('ro', 'en', dataset='wmt16') # saves to wmt16-ro-en """ try: import nlp except (ModuleNotFoundError, ImportError): raise ImportError("run pip install nlp") pair = f"{src_lang}-{tgt_lang}" + print(f"Converting {dataset}-{pair}") ds = nlp.load_dataset(dataset, pair) if save_dir is None: save_dir = f"{dataset}-{pair}" save_dir = Path(save_dir) save_dir.mkdir(exist_ok=True) - for split in tqdm(ds.keys()): - tr_list = list(ds[split]) - data = [x["translation"] for x in tr_list] - src, tgt = [], [] - for example in data: - src.append(example[src_lang]) - tgt.append(example[tgt_lang]) - if split == "validation": - split = "val" # to save to val.source, val.target like summary datasets - src_path = save_dir.joinpath(f"{split}.source") - src_path.open("w+").write("\n".join(src)) - tgt_path = save_dir.joinpath(f"{split}.target") - tgt_path.open("w+").write("\n".join(tgt)) - print(f"saved dataset to {save_dir}") + for split in ds.keys(): + print(f"Splitting {split} with {ds[split].num_rows} records") + + # to save to val.source, val.target like summary datasets + fn = "val" if split == "validation" else split + src_path = save_dir.joinpath(f"{fn}.source") + tgt_path = save_dir.joinpath(f"{fn}.target") + src_fp = src_path.open("w+") + tgt_fp = tgt_path.open("w+") + + # reader is the bottleneck so writing one record at a time doesn't slow things down + for x in tqdm(ds[split]): + ex = x["translation"] + src_fp.write(ex[src_lang] + "\n") + tgt_fp.write(ex[tgt_lang] + "\n") + + print(f"Saved {dataset} dataset to {save_dir}") if __name__ == "__main__":