[pack_dataset] don't sort before packing, only pack train (#5954)
This commit is contained in:
@@ -13,3 +13,4 @@ streamlit
|
||||
elasticsearch
|
||||
pandas
|
||||
nlp
|
||||
fire
|
||||
|
||||
19
examples/seq2seq/minify_dataset.py
Normal file
19
examples/seq2seq/minify_dataset.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
|
||||
|
||||
def minify(src_dir: str, dest_dir: str, n: int):
|
||||
"""Write first n lines of each file f in src_dir to dest_dir/f """
|
||||
src_dir = Path(src_dir)
|
||||
dest_dir = Path(dest_dir)
|
||||
dest_dir.mkdir(exist_ok=True)
|
||||
for path in src_dir.iterdir():
|
||||
new = [x.rstrip() for x in list(path.open().readlines())][:n]
|
||||
dest_path = dest_dir.joinpath(path.name)
|
||||
print(dest_path)
|
||||
dest_path.open("w").write("\n".join(new))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(minify)
|
||||
@@ -6,6 +6,7 @@
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from tqdm import tqdm
|
||||
@@ -17,7 +18,7 @@ def pack_examples(tok, src_examples, tgt_examples, max_tokens=1024):
|
||||
|
||||
finished_src, finished_tgt = [], []
|
||||
|
||||
sorted_examples = list(sorted(zip(src_examples, tgt_examples), key=lambda x: len(x[0])))
|
||||
sorted_examples = list(zip(src_examples, tgt_examples))
|
||||
new_src, new_tgt = sorted_examples[0]
|
||||
|
||||
def is_too_big(strang):
|
||||
@@ -42,20 +43,10 @@ def pack_examples(tok, src_examples, tgt_examples, max_tokens=1024):
|
||||
return finished_src, finished_tgt
|
||||
|
||||
|
||||
def minify(src_dir: Path, dest_dir: Path, n: int):
|
||||
"""Write first n lines of each file f in src_dir to dest_dir/f"""
|
||||
dest_dir.mkdir(exist_ok=True)
|
||||
for path in src_dir.iterdir():
|
||||
new = [x.rstrip() for x in list(path.open().readlines())][:n]
|
||||
dest_path = dest_dir.joinpath(path.name)
|
||||
print(dest_path)
|
||||
dest_path.open("w").write("\n".join(new))
|
||||
|
||||
|
||||
def pack_data_dir(tok, data_dir: Path, max_tokens, save_path):
|
||||
save_path = Path(save_path)
|
||||
save_path.mkdir(exist_ok=True)
|
||||
for split in ["val", "test", "train"]:
|
||||
for split in ["train"]:
|
||||
src_path, tgt_path = data_dir / f"{split}.source", data_dir / f"{split}.target"
|
||||
src_docs = [x.rstrip() for x in Path(src_path).open().readlines()]
|
||||
tgt_docs = [x.rstrip() for x in Path(tgt_path).open().readlines()]
|
||||
@@ -63,6 +54,10 @@ def pack_data_dir(tok, data_dir: Path, max_tokens, save_path):
|
||||
print(f"packed {split} split from {len(src_docs)} examples -> {len(packed_src)}.")
|
||||
Path(save_path / f"{split}.source").open("w").write("\n".join(packed_src))
|
||||
Path(save_path / f"{split}.target").open("w").write("\n".join(packed_tgt))
|
||||
for split in ["val", "test"]:
|
||||
src_path, tgt_path = data_dir / f"{split}.source", data_dir / f"{split}.target"
|
||||
shutil.copyfile(src_path, save_path / f"{split}.source")
|
||||
shutil.copyfile(tgt_path, save_path / f"{split}.target")
|
||||
|
||||
|
||||
def packer_cli():
|
||||
|
||||
Reference in New Issue
Block a user