Seq2SeqDataset: avoid passing src_lang everywhere (#7470)

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Amanpreet Singh
2020-09-30 13:27:48 -04:00
committed by GitHub
parent 08939cfdf7
commit c031d01023
2 changed files with 50 additions and 23 deletions

View File

@@ -185,3 +185,36 @@ def test_distributed_sortish_sampler_splits_indices_between_procs():
ids1 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=0, add_extra_examples=False))
ids2 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=1, add_extra_examples=False))
assert ids1.intersection(ids2) == set()
@pytest.mark.parametrize(
"tok_name",
[
MBART_TINY,
MARIAN_TINY,
T5_TINY,
BART_TINY,
PEGASUS_XSUM,
],
)
def test_dataset_kwargs(tok_name):
tokenizer = AutoTokenizer.from_pretrained(tok_name)
if tok_name == MBART_TINY:
train_dataset = Seq2SeqDataset(
tokenizer,
data_dir=make_test_data_dir(),
type_path="train",
max_source_length=4,
max_target_length=8,
src_lang="EN",
tgt_lang="FR",
)
kwargs = train_dataset.dataset_kwargs
assert "src_lang" in kwargs and "tgt_lang" in kwargs
else:
train_dataset = Seq2SeqDataset(
tokenizer, data_dir=make_test_data_dir(), type_path="train", max_source_length=4, max_target_length=8
)
kwargs = train_dataset.dataset_kwargs
assert "add_prefix_space" not in kwargs if tok_name != BART_TINY else "add_prefix_space" in kwargs
assert len(kwargs) == 1 if tok_name == BART_TINY else len(kwargs) == 0