[seq2seq] pack_dataset.py rewrites dataset in max_tokens format (#5819)
This commit is contained in:
@@ -16,6 +16,7 @@ from transformers.testing_utils import require_multigpu
|
||||
|
||||
from .distillation import distill_main, evaluate_checkpoint
|
||||
from .finetune import main
|
||||
from .pack_dataset import pack_data_dir
|
||||
from .run_eval import generate_summaries_or_translations, run_generate
|
||||
from .utils import SummarizationDataset, lmap, load_json
|
||||
|
||||
@@ -249,6 +250,16 @@ def test_finetune(model):
|
||||
assert bart.decoder.embed_tokens == bart.shared
|
||||
|
||||
|
||||
def test_pack_dataset():
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
|
||||
tmp_dir = Path(make_test_data_dir())
|
||||
save_dir = Path(tempfile.mkdtemp(prefix="packed_"))
|
||||
pack_data_dir(tokenizer, tmp_dir, 128, save_dir)
|
||||
orig_paths = {x.name for x in tmp_dir.iterdir()}
|
||||
new_paths = {x.name for x in save_dir.iterdir()}
|
||||
assert orig_paths == new_paths
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user