Seq2SeqDataset: avoid passing src_lang everywhere (#7470)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -185,3 +185,36 @@ def test_distributed_sortish_sampler_splits_indices_between_procs():
|
||||
ids1 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=0, add_extra_examples=False))
|
||||
ids2 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=1, add_extra_examples=False))
|
||||
assert ids1.intersection(ids2) == set()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tok_name",
|
||||
[
|
||||
MBART_TINY,
|
||||
MARIAN_TINY,
|
||||
T5_TINY,
|
||||
BART_TINY,
|
||||
PEGASUS_XSUM,
|
||||
],
|
||||
)
|
||||
def test_dataset_kwargs(tok_name):
|
||||
tokenizer = AutoTokenizer.from_pretrained(tok_name)
|
||||
if tok_name == MBART_TINY:
|
||||
train_dataset = Seq2SeqDataset(
|
||||
tokenizer,
|
||||
data_dir=make_test_data_dir(),
|
||||
type_path="train",
|
||||
max_source_length=4,
|
||||
max_target_length=8,
|
||||
src_lang="EN",
|
||||
tgt_lang="FR",
|
||||
)
|
||||
kwargs = train_dataset.dataset_kwargs
|
||||
assert "src_lang" in kwargs and "tgt_lang" in kwargs
|
||||
else:
|
||||
train_dataset = Seq2SeqDataset(
|
||||
tokenizer, data_dir=make_test_data_dir(), type_path="train", max_source_length=4, max_target_length=8
|
||||
)
|
||||
kwargs = train_dataset.dataset_kwargs
|
||||
assert "add_prefix_space" not in kwargs if tok_name != BART_TINY else "add_prefix_space" in kwargs
|
||||
assert len(kwargs) == 1 if tok_name == BART_TINY else len(kwargs) == 0
|
||||
|
||||
Reference in New Issue
Block a user