[QOL] add signature for prepare_seq2seq_batch (#7108)
This commit is contained in:
@@ -1566,14 +1566,17 @@ class TokenizerTesterMixin:
|
||||
'pentru Siria este că "nu există o soluţie militară" la conflictul de aproape cinci ani şi că noi arme nu '
|
||||
"vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.",
|
||||
]
|
||||
batch = tokenizer.prepare_seq2seq_batch(
|
||||
src_texts=src_text,
|
||||
tgt_texts=tgt_text,
|
||||
max_length=3,
|
||||
max_target_length=10,
|
||||
return_tensors="pt",
|
||||
src_lang="en_XX", # this should be ignored (for all but mbart) but not cause an error
|
||||
)
|
||||
try:
|
||||
batch = tokenizer.prepare_seq2seq_batch(
|
||||
src_texts=src_text,
|
||||
tgt_texts=tgt_text,
|
||||
max_length=3,
|
||||
max_target_length=10,
|
||||
return_tensors="pt",
|
||||
src_lang="en_XX", # this should be ignored (for all but mbart) but not cause an error
|
||||
)
|
||||
except NotImplementedError:
|
||||
return
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.labels.shape[1], 10)
|
||||
# max_target_length will default to max_length if not specified
|
||||
|
||||
Reference in New Issue
Block a user