prepare_seq2seq_batch makes labels/ decoder_input_ids made later. (#6654)

* broken test

* batch parity

* tests pass

* boom boom

* boom boom

* split out bart tokenizer tests

* fix tests

* boom boom

* Fixed dataset bug

* Fix marian

* Undo extra

* Get marian working

* Fix t5 tok tests

* Test passing

* Cleanup

* better assert msg

* require torch

* Fix mbart tests

* undo extra decoder_attn_mask change

* Fix import

* pegasus tokenizer can ignore src_lang kwargs

* unused kwarg test cov

* boom boom

* add todo for pegasus issue

* cover one word translation edge case

* Cleanup

* doc
This commit is contained in:
Sam Shleifer
2020-08-28 11:15:17 -04:00
committed by GitHub
parent cb276b41de
commit 9336086ab5
20 changed files with 429 additions and 290 deletions

View File

@@ -1,13 +1,16 @@
import tempfile
import unittest
from transformers import AutoTokenizer, BatchEncoding, MBartTokenizer
from transformers import AutoTokenizer, BatchEncoding, MBartTokenizer, is_torch_available
from transformers.testing_utils import require_torch
from .test_tokenization_common import TokenizerTesterMixin
from .test_tokenization_xlm_roberta import SAMPLE_VOCAB, SPIECE_UNDERLINE
if is_torch_available():
from transformers.modeling_bart import shift_tokens_right
EN_CODE = 250004
RO_CODE = 250020
@@ -123,35 +126,6 @@ class MBartEnroIntegrationTest(unittest.TestCase):
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_EN"], 250004)
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ro_RO"], 250020)
def test_enro_tokenizer_prepare_seq2seq_batch(self):
batch = self.tokenizer.prepare_seq2seq_batch(
self.src_text,
tgt_texts=self.tgt_text,
max_length=len(self.expected_src_tokens),
)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual((2, 14), batch.input_ids.shape)
self.assertEqual((2, 14), batch.attention_mask.shape)
result = batch.input_ids.tolist()[0]
self.assertListEqual(self.expected_src_tokens, result)
self.assertEqual(2, batch.decoder_input_ids[0, -1]) # EOS
# Test that special tokens are reset
self.assertEqual(self.tokenizer.prefix_tokens, [])
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
def test_max_target_length(self):
batch = self.tokenizer.prepare_seq2seq_batch(
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10
)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
# max_target_length will default to max_length if not specified
batch = self.tokenizer.prepare_seq2seq_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 3)
def test_enro_tokenizer_batch_encode_plus(self):
ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
self.assertListEqual(self.expected_src_tokens, ids)
@@ -169,7 +143,9 @@ class MBartEnroIntegrationTest(unittest.TestCase):
assert isinstance(src_text[0], str)
desired_max_length = 10
ids = self.tokenizer.prepare_seq2seq_batch(
src_text, return_tensors=None, max_length=desired_max_length
src_text,
return_tensors=None,
max_length=desired_max_length,
).input_ids[0]
self.assertEqual(ids[-2], 2)
self.assertEqual(ids[-1], EN_CODE)
@@ -184,3 +160,53 @@ class MBartEnroIntegrationTest(unittest.TestCase):
self.tokenizer.save_pretrained(tmpdirname)
new_tok = MBartTokenizer.from_pretrained(tmpdirname)
self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)
# prepare_seq2seq_batch tests below
@require_torch
def test_batch_fairseq_parity(self):
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
self.src_text, tgt_texts=self.tgt_text, return_tensors="pt"
)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
for k in batch:
batch[k] = batch[k].tolist()
# batch = {k: v.tolist() for k,v in batch.items()}
# fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
# batch.decoder_inputs_ids[0][0] ==
assert batch.input_ids[1][-2:] == [2, EN_CODE]
assert batch.decoder_input_ids[1][0] == RO_CODE
assert batch.decoder_input_ids[1][-1] == 2
assert batch.labels[1][-2:] == [2, RO_CODE]
@require_torch
def test_enro_tokenizer_prepare_seq2seq_batch(self):
batch = self.tokenizer.prepare_seq2seq_batch(
self.src_text,
tgt_texts=self.tgt_text,
max_length=len(self.expected_src_tokens),
)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual((2, 14), batch.input_ids.shape)
self.assertEqual((2, 14), batch.attention_mask.shape)
result = batch.input_ids.tolist()[0]
self.assertListEqual(self.expected_src_tokens, result)
self.assertEqual(2, batch.decoder_input_ids[0, -1]) # EOS
# Test that special tokens are reset
self.assertEqual(self.tokenizer.prefix_tokens, [])
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
def test_seq2seq_max_target_length(self):
batch = self.tokenizer.prepare_seq2seq_batch(
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10
)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
# max_target_length will default to max_length if not specified
batch = self.tokenizer.prepare_seq2seq_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 3)