@@ -16,13 +16,14 @@ from transformers import AutoTokenizer
|
|||||||
def pack_examples(tok, src_examples, tgt_examples, max_tokens=1024):
|
def pack_examples(tok, src_examples, tgt_examples, max_tokens=1024):
|
||||||
|
|
||||||
finished_src, finished_tgt = [], []
|
finished_src, finished_tgt = [], []
|
||||||
new_src, new_tgt = "", ""
|
|
||||||
sorted_examples = list(sorted(zip(src_examples, tgt_examples), key=lambda x: len(x[0])))
|
sorted_examples = list(sorted(zip(src_examples, tgt_examples), key=lambda x: len(x[0])))
|
||||||
|
new_src, new_tgt = sorted_examples[0]
|
||||||
|
|
||||||
def is_too_big(strang):
|
def is_too_big(strang):
|
||||||
return tok(strang, return_tensors="pt").input_ids.shape[1] > max_tokens
|
return tok(strang, return_tensors="pt").input_ids.shape[1] > max_tokens
|
||||||
|
|
||||||
for src, tgt in tqdm(sorted_examples):
|
for src, tgt in tqdm(sorted_examples[1:]):
|
||||||
cand_src = new_src + " " + src
|
cand_src = new_src + " " + src
|
||||||
cand_tgt = new_tgt + " " + tgt
|
cand_tgt = new_tgt + " " + tgt
|
||||||
if is_too_big(cand_src) or is_too_big(cand_tgt): # cant fit, finalize example
|
if is_too_big(cand_src) or is_too_big(cand_tgt): # cant fit, finalize example
|
||||||
@@ -31,21 +32,37 @@ def pack_examples(tok, src_examples, tgt_examples, max_tokens=1024):
|
|||||||
new_src, new_tgt = src, tgt
|
new_src, new_tgt = src, tgt
|
||||||
else: # can fit, keep adding
|
else: # can fit, keep adding
|
||||||
new_src, new_tgt = cand_src, cand_tgt
|
new_src, new_tgt = cand_src, cand_tgt
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
|
||||||
|
# cleanup
|
||||||
|
if new_src:
|
||||||
|
assert new_tgt
|
||||||
|
finished_src.append(new_src)
|
||||||
|
finished_tgt.append(new_tgt)
|
||||||
return finished_src, finished_tgt
|
return finished_src, finished_tgt
|
||||||
|
|
||||||
|
|
||||||
|
def minify(src_dir: Path, dest_dir: Path, n: int):
|
||||||
|
"""Write first n lines of each file f in src_dir to dest_dir/f"""
|
||||||
|
dest_dir.mkdir(exist_ok=True)
|
||||||
|
for path in src_dir.iterdir():
|
||||||
|
new = [x.rstrip() for x in list(path.open().readlines())][:n]
|
||||||
|
dest_path = dest_dir.joinpath(path.name)
|
||||||
|
print(dest_path)
|
||||||
|
dest_path.open("w").write("\n".join(new))
|
||||||
|
|
||||||
|
|
||||||
def pack_data_dir(tok, data_dir: Path, max_tokens, save_path):
|
def pack_data_dir(tok, data_dir: Path, max_tokens, save_path):
|
||||||
save_path = Path(save_path)
|
save_path = Path(save_path)
|
||||||
save_path.mkdir(exist_ok=True)
|
save_path.mkdir(exist_ok=True)
|
||||||
for split in ["val", "test", "train"]:
|
for split in ["val", "test", "train"]:
|
||||||
src_path, tgt_path = data_dir / f"{split}.source", data_dir / f"{split}.target"
|
src_path, tgt_path = data_dir / f"{split}.source", data_dir / f"{split}.target"
|
||||||
src_docs = list(Path(src_path).open().readlines())
|
src_docs = [x.rstrip() for x in Path(src_path).open().readlines()]
|
||||||
tgt_docs = list(Path(tgt_path).open().readlines())
|
tgt_docs = [x.rstrip() for x in Path(tgt_path).open().readlines()]
|
||||||
src, tgt = pack_examples(tok, src_docs, tgt_docs, max_tokens)
|
packed_src, packed_tgt = pack_examples(tok, src_docs, tgt_docs, max_tokens)
|
||||||
print(f"packed {split} split from {len(src_docs)} examples -> {len(src)}.")
|
print(f"packed {split} split from {len(src_docs)} examples -> {len(packed_src)}.")
|
||||||
Path(save_path / f"{split}.source").open("w").write("\n".join(src))
|
Path(save_path / f"{split}.source").open("w").write("\n".join(packed_src))
|
||||||
Path(save_path / f"{split}.target").open("w").write("\n".join(tgt))
|
Path(save_path / f"{split}.target").open("w").write("\n".join(packed_tgt))
|
||||||
|
|
||||||
|
|
||||||
def packer_cli():
|
def packer_cli():
|
||||||
|
|||||||
@@ -254,11 +254,19 @@ def test_finetune(model):
|
|||||||
|
|
||||||
def test_pack_dataset():
|
def test_pack_dataset():
|
||||||
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
|
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
|
||||||
|
|
||||||
tmp_dir = Path(make_test_data_dir())
|
tmp_dir = Path(make_test_data_dir())
|
||||||
|
orig_examples = tmp_dir.joinpath("train.source").open().readlines()
|
||||||
save_dir = Path(tempfile.mkdtemp(prefix="packed_"))
|
save_dir = Path(tempfile.mkdtemp(prefix="packed_"))
|
||||||
pack_data_dir(tokenizer, tmp_dir, 128, save_dir)
|
pack_data_dir(tokenizer, tmp_dir, 128, save_dir)
|
||||||
orig_paths = {x.name for x in tmp_dir.iterdir()}
|
orig_paths = {x.name for x in tmp_dir.iterdir()}
|
||||||
new_paths = {x.name for x in save_dir.iterdir()}
|
new_paths = {x.name for x in save_dir.iterdir()}
|
||||||
|
packed_examples = save_dir.joinpath("train.source").open().readlines()
|
||||||
|
# orig: [' Sam ate lunch today.\n', 'Sams lunch ingredients.']
|
||||||
|
# desired_packed: [' Sam ate lunch today.\n Sams lunch ingredients.']
|
||||||
|
assert len(packed_examples) < len(orig_examples)
|
||||||
|
assert len(packed_examples) == 1
|
||||||
|
assert len(packed_examples[0]) == sum(len(x) for x in orig_examples)
|
||||||
assert orig_paths == new_paths
|
assert orig_paths == new_paths
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user