rename prepare_translation_batch -> prepare_seq2seq_batch (#6103)
This commit is contained in:
@@ -1522,3 +1522,37 @@ class TokenizerTesterMixin:
|
||||
|
||||
if batch_encoded_sequence_fast is None:
|
||||
raise ValueError("Cannot convert list to numpy tensor on batch_encode_plus() (fast)")
|
||||
|
||||
@require_torch
|
||||
def test_prepare_seq2seq_batch(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
if not hasattr(tokenizer, "prepare_seq2seq_batch"):
|
||||
return
|
||||
# Longer text that will definitely require truncation.
|
||||
src_text = [
|
||||
" UN Chief Says There Is No Military Solution in Syria",
|
||||
" Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for Syria is that 'there is no military solution' to the nearly five-year conflict and more weapons will only worsen the violence and misery for millions of people.",
|
||||
]
|
||||
tgt_text = [
|
||||
"Şeful ONU declară că nu există o soluţie militară în Siria",
|
||||
"Secretarul General Ban Ki-moon declară că răspunsul său la intensificarea sprijinului militar al Rusiei "
|
||||
'pentru Siria este că "nu există o soluţie militară" la conflictul de aproape cinci ani şi că noi arme nu '
|
||||
"vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.",
|
||||
]
|
||||
batch = tokenizer.prepare_seq2seq_batch(
|
||||
src_texts=src_text, tgt_texts=tgt_text, max_length=3, max_target_length=10, return_tensors="pt"
|
||||
)
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
|
||||
# max_target_length will default to max_length if not specified
|
||||
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, max_length=3)
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 3)
|
||||
|
||||
batch_encoder_only = tokenizer.prepare_seq2seq_batch(
|
||||
src_texts=src_text, max_length=3, max_target_length=10, return_tensors="pt"
|
||||
)
|
||||
self.assertEqual(batch_encoder_only.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch_encoder_only.attention_mask.shape[1], 3)
|
||||
self.assertNotIn("decoder_input_ids", batch_encoder_only)
|
||||
|
||||
Reference in New Issue
Block a user