prepare_seq2seq_batch makes labels; decoder_input_ids are made later. (#6654)
* broken test * batch parity * tests pass * boom boom * boom boom * split out bart tokenizer tests * fix tests * boom boom * Fixed dataset bug * Fix marian * Undo extra * Get marian working * Fix t5 tok tests * Test passing * Cleanup * better assert msg * require torch * Fix mbart tests * undo extra decoder_attn_mask change * Fix import * pegasus tokenizer can ignore src_lang kwargs * unused kwarg test cov * boom boom * add todo for pegasus issue * cover one word translation edge case * Cleanup * doc
This commit is contained in:
@@ -1,13 +1,16 @@
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import AutoTokenizer, BatchEncoding, MBartTokenizer
|
||||
from transformers import AutoTokenizer, BatchEncoding, MBartTokenizer, is_torch_available
|
||||
from transformers.testing_utils import require_torch
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
from .test_tokenization_xlm_roberta import SAMPLE_VOCAB, SPIECE_UNDERLINE
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers.modeling_bart import shift_tokens_right
|
||||
|
||||
EN_CODE = 250004
|
||||
RO_CODE = 250020
|
||||
|
||||
@@ -123,35 +126,6 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_EN"], 250004)
|
||||
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ro_RO"], 250020)
|
||||
|
||||
def test_enro_tokenizer_prepare_seq2seq_batch(self):
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||
self.src_text,
|
||||
tgt_texts=self.tgt_text,
|
||||
max_length=len(self.expected_src_tokens),
|
||||
)
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
|
||||
self.assertEqual((2, 14), batch.input_ids.shape)
|
||||
self.assertEqual((2, 14), batch.attention_mask.shape)
|
||||
result = batch.input_ids.tolist()[0]
|
||||
self.assertListEqual(self.expected_src_tokens, result)
|
||||
self.assertEqual(2, batch.decoder_input_ids[0, -1]) # EOS
|
||||
# Test that special tokens are reset
|
||||
self.assertEqual(self.tokenizer.prefix_tokens, [])
|
||||
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
|
||||
|
||||
def test_max_target_length(self):
|
||||
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10
|
||||
)
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
|
||||
# max_target_length will default to max_length if not specified
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3)
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 3)
|
||||
|
||||
def test_enro_tokenizer_batch_encode_plus(self):
|
||||
ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
|
||||
self.assertListEqual(self.expected_src_tokens, ids)
|
||||
@@ -169,7 +143,9 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
||||
assert isinstance(src_text[0], str)
|
||||
desired_max_length = 10
|
||||
ids = self.tokenizer.prepare_seq2seq_batch(
|
||||
src_text, return_tensors=None, max_length=desired_max_length
|
||||
src_text,
|
||||
return_tensors=None,
|
||||
max_length=desired_max_length,
|
||||
).input_ids[0]
|
||||
self.assertEqual(ids[-2], 2)
|
||||
self.assertEqual(ids[-1], EN_CODE)
|
||||
@@ -184,3 +160,53 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
||||
self.tokenizer.save_pretrained(tmpdirname)
|
||||
new_tok = MBartTokenizer.from_pretrained(tmpdirname)
|
||||
self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)
|
||||
|
||||
# prepare_seq2seq_batch tests below
|
||||
|
||||
@require_torch
def test_batch_fairseq_parity(self):
    """Check the special-token layout of a seq2seq batch against a fairseq batch.

    ``decoder_input_ids`` is not taken from the tokenizer output; it is built
    here by applying ``shift_tokens_right`` to ``batch.labels``.
    """
    batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
        self.src_text, tgt_texts=self.tgt_text, return_tensors="pt"
    )
    batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
    # Turn every entry into nested lists so slices can be compared with list literals.
    for k in batch:
        batch[k] = batch[k].tolist()

    # fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
    # unittest assertions (rather than bare ``assert``) survive ``python -O`` and
    # report both operands on failure.
    self.assertEqual(batch.input_ids[1][-2:], [2, EN_CODE])
    self.assertEqual(batch.decoder_input_ids[1][0], RO_CODE)
    self.assertEqual(batch.decoder_input_ids[1][-1], 2)
    self.assertEqual(batch.labels[1][-2:], [2, RO_CODE])
|
||||
|
||||
@require_torch
def test_enro_tokenizer_prepare_seq2seq_batch(self):
    """Validate type, shapes, source ids and special-token reset for an en-ro batch."""
    source_len = len(self.expected_src_tokens)
    batch = self.tokenizer.prepare_seq2seq_batch(
        self.src_text, tgt_texts=self.tgt_text, max_length=source_len
    )
    # Decoder inputs are produced here from the labels, not by the tokenizer call.
    decoder_inputs = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
    batch["decoder_input_ids"] = decoder_inputs
    self.assertIsInstance(batch, BatchEncoding)

    for tensor in (batch.input_ids, batch.attention_mask):
        self.assertEqual((2, 14), tensor.shape)

    first_source_row = batch.input_ids.tolist()[0]
    self.assertListEqual(self.expected_src_tokens, first_source_row)
    # The first decoder row must end with EOS (id 2).
    self.assertEqual(2, batch.decoder_input_ids[0, -1])

    # Test that special tokens are reset
    self.assertEqual(self.tokenizer.prefix_tokens, [])
    expected_suffix = [self.tokenizer.eos_token_id, EN_CODE]
    self.assertEqual(self.tokenizer.suffix_tokens, expected_suffix)
|
||||
|
||||
@require_torch
def test_seq2seq_max_target_length(self):
    """Check that ``max_length`` and ``max_target_length`` bound the batch widths.

    Marked ``@require_torch`` like the sibling seq2seq tests because it calls
    ``shift_tokens_right``, which is imported only when torch is available.
    """
    batch = self.tokenizer.prepare_seq2seq_batch(
        self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10
    )
    batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
    self.assertEqual(batch.input_ids.shape[1], 3)
    self.assertEqual(batch.decoder_input_ids.shape[1], 10)

    # max_target_length will default to max_length if not specified
    batch = self.tokenizer.prepare_seq2seq_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3)
    batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
    self.assertEqual(batch.input_ids.shape[1], 3)
    self.assertEqual(batch.decoder_input_ids.shape[1], 3)
|
||||
|
||||
Reference in New Issue
Block a user