[seq2seq] pack_dataset.py rewrites dataset in max_tokens format (#5819)

This commit is contained in:
Sam Shleifer
2020-07-16 14:06:49 -04:00
committed by GitHub
parent c45d7a707d
commit 283500ff9f
2 changed files with 74 additions and 0 deletions

View File

@@ -16,6 +16,7 @@ from transformers.testing_utils import require_multigpu
from .distillation import distill_main, evaluate_checkpoint
from .finetune import main
from .pack_dataset import pack_data_dir
from .run_eval import generate_summaries_or_translations, run_generate
from .utils import SummarizationDataset, lmap, load_json
@@ -249,6 +250,16 @@ def test_finetune(model):
assert bart.decoder.embed_tokens == bart.shared
def test_pack_dataset():
    """Check that packing a data dir yields the same set of file names.

    Runs ``pack_data_dir`` from a freshly made test data directory into a new
    temporary directory, then compares the entry names of both directories.
    Only names are compared; file contents are not inspected here.
    """
    mbart_tok = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
    source_dir = Path(make_test_data_dir())
    packed_dir = Path(tempfile.mkdtemp(prefix="packed_"))

    # NOTE(review): 128 is presumably the max-token budget per packed example
    # -- confirm against pack_data_dir's signature.
    pack_data_dir(mbart_tok, source_dir, 128, packed_dir)

    expected_names = {entry.name for entry in source_dir.iterdir()}
    packed_names = {entry.name for entry in packed_dir.iterdir()}
    assert expected_names == packed_names
@pytest.mark.parametrize(
["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]
)