prepare_seq2seq_batch makes labels/ decoder_input_ids made later. (#6654)

* broken test

* batch parity

* tests pass

* boom boom

* boom boom

* split out bart tokenizer tests

* fix tests

* boom boom

* Fixed dataset bug

* Fix marian

* Undo extra

* Get marian working

* Fix t5 tok tests

* Test passing

* Cleanup

* better assert msg

* require torch

* Fix mbart tests

* undo extra decoder_attn_mask change

* Fix import

* pegasus tokenizer can ignore src_lang kwargs

* unused kwarg test cov

* boom boom

* add todo for pegasus issue

* cover one word translation edge case

* Cleanup

* doc
This commit is contained in:
Sam Shleifer
2020-08-28 11:15:17 -04:00
committed by GitHub
parent cb276b41de
commit 9336086ab5
20 changed files with 429 additions and 290 deletions

View File

@@ -18,7 +18,7 @@ import unittest
import timeout_decorator # noqa
from transformers import BatchEncoding, is_torch_available
from transformers import is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, slow, torch_device
@@ -496,7 +496,7 @@ class BartModelIntegrationTests(unittest.TestCase):
def test_xsum_summarization_same_as_fairseq(self):
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-xsum").to(torch_device)
self.assertFalse(model.config.is_valid_mbart())
tok = BartTokenizer.from_pretrained("facebook/bart-large")
tok = self.default_tokenizer
EXPECTED_SUMMARY = "California's largest power company has begun shutting off electricity to thousands of customers in the state."
dct = tok.batch_encode_plus(
@@ -585,84 +585,6 @@ class BartModelIntegrationTests(unittest.TestCase):
# TODO(SS): run fairseq again with num_beams=2, min_len=20.
# TODO(SS): add test case that hits max_length
def test_prepare_seq2seq_batch(self):
tokenizers = [self.default_tokenizer, self.default_tokenizer_fast]
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
]
expected_src_tokens = [0, 250, 251, 17818, 13, 32933, 21645, 1258, 4, 2]
for tokenizer in tokenizers:
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors="pt"
)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual((2, 10), batch.input_ids.shape)
self.assertEqual((2, 10), batch.attention_mask.shape)
result = batch.input_ids.tolist()[0]
self.assertListEqual(expected_src_tokens, result)
# Test that special tokens are reset
def test_empty_target_text(self):
tokenizers = [self.default_tokenizer, self.default_tokenizer_fast]
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
for tokenizer in tokenizers:
batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt")
# check if input_ids are returned and no decoder_input_ids
self.assertIn("input_ids", batch)
self.assertIn("attention_mask", batch)
self.assertNotIn("decoder_input_ids", batch)
self.assertNotIn("decoder_attention_mask", batch)
def test_max_target_length(self):
tokenizers = [self.default_tokenizer, self.default_tokenizer_fast]
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
]
for tokenizer in tokenizers:
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors="pt"
)
self.assertEqual(32, batch["decoder_input_ids"].shape[1])
self.assertEqual(32, batch["decoder_attention_mask"].shape[1])
# test None max_target_length
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors="pt"
)
self.assertEqual(32, batch["decoder_input_ids"].shape[1])
self.assertEqual(32, batch["decoder_attention_mask"].shape[1])
def test_outputs_not_longer_than_maxlen(self):
tokenizers = [self.default_tokenizer, self.default_tokenizer_fast]
for tokenizer in tokenizers:
batch = tokenizer.prepare_seq2seq_batch(
["I am a small frog" * 1024, "I am a small frog"], return_tensors="pt"
)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual(batch.input_ids.shape, (2, 1024))
def test_special_tokens(self):
tokenizers = [self.default_tokenizer, self.default_tokenizer_fast]
src_text = ["A long paragraph for summrization."]
tgt_text = [
"Summary of the text.",
]
for tokenizer in tokenizers:
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors="pt")
input_ids = batch["input_ids"]
decoder_input_ids = batch["decoder_input_ids"]
self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item())
self.assertTrue((decoder_input_ids[:, 0] == tokenizer.bos_token_id).all().item())
self.assertTrue((input_ids[:, -1] == tokenizer.eos_token_id).all().item())
self.assertTrue((decoder_input_ids[:, -1] == tokenizer.eos_token_id).all().item())
@require_torch
class TestSinusoidalPositionalEmbeddings(unittest.TestCase):

View File

@@ -0,0 +1,145 @@
import json
import os
import unittest
from transformers import BartTokenizer, BartTokenizerFast, BatchEncoding
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch
from transformers.tokenization_roberta import VOCAB_FILES_NAMES
from .test_tokenization_common import TokenizerTesterMixin
class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = BartTokenizer
def setUp(self):
super().setUp()
vocab = [
"l",
"o",
"w",
"e",
"r",
"s",
"t",
"i",
"d",
"n",
"\u0120",
"\u0120l",
"\u0120n",
"\u0120lo",
"\u0120low",
"er",
"\u0120lowest",
"\u0120newer",
"\u0120wider",
"<unk>",
]
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
self.special_tokens_map = {"unk_token": "<unk>"}
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
with open(self.vocab_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(vocab_tokens) + "\n")
with open(self.merges_file, "w", encoding="utf-8") as fp:
fp.write("\n".join(merges))
def get_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
def get_rust_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
return BartTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self, tokenizer):
return "lower newer", "lower newer"
@cached_property
def default_tokenizer(self):
return BartTokenizer.from_pretrained("facebook/bart-large")
@cached_property
def default_tokenizer_fast(self):
return BartTokenizerFast.from_pretrained("facebook/bart-large")
@require_torch
def test_prepare_seq2seq_batch(self):
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
]
expected_src_tokens = [0, 250, 251, 17818, 13, 32933, 21645, 1258, 4, 2]
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors="pt"
)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual((2, 10), batch.input_ids.shape)
self.assertEqual((2, 10), batch.attention_mask.shape)
result = batch.input_ids.tolist()[0]
self.assertListEqual(expected_src_tokens, result)
# Test that special tokens are reset
# Test Prepare Seq
@require_torch
def test_seq2seq_batch_empty_target_text(self):
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt")
# check if input_ids are returned and no labels
self.assertIn("input_ids", batch)
self.assertIn("attention_mask", batch)
self.assertNotIn("labels", batch)
self.assertNotIn("decoder_attention_mask", batch)
@require_torch
def test_seq2seq_batch_max_target_length(self):
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
]
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors="pt"
)
self.assertEqual(32, batch["labels"].shape[1])
# test None max_target_length
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors="pt"
)
self.assertEqual(32, batch["labels"].shape[1])
@require_torch
def test_seq2seq_batch_not_longer_than_maxlen(self):
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
batch = tokenizer.prepare_seq2seq_batch(
["I am a small frog" * 1024, "I am a small frog"], return_tensors="pt"
)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual(batch.input_ids.shape, (2, 1024))
@require_torch
def test_special_tokens(self):
src_text = ["A long paragraph for summrization."]
tgt_text = [
"Summary of the text.",
]
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors="pt")
input_ids = batch["input_ids"]
labels = batch["labels"]
self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item())
self.assertTrue((labels[:, 0] == tokenizer.bos_token_id).all().item())
self.assertTrue((input_ids[:, -1] == tokenizer.eos_token_id).all().item())
self.assertTrue((labels[:, -1] == tokenizer.eos_token_id).all().item())

View File

@@ -1555,14 +1555,19 @@ class TokenizerTesterMixin:
"vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.",
]
batch = tokenizer.prepare_seq2seq_batch(
src_texts=src_text, tgt_texts=tgt_text, max_length=3, max_target_length=10, return_tensors="pt"
src_texts=src_text,
tgt_texts=tgt_text,
max_length=3,
max_target_length=10,
return_tensors="pt",
src_lang="en_XX", # this should be ignored (for all but mbart) but not cause an error
)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
self.assertEqual(batch.labels.shape[1], 10)
# max_target_length will default to max_length if not specified
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, max_length=3)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 3)
self.assertEqual(batch.labels.shape[1], 3)
batch_encoder_only = tokenizer.prepare_seq2seq_batch(
src_texts=src_text, max_length=3, max_target_length=10, return_tensors="pt"

View File

@@ -1,13 +1,16 @@
import tempfile
import unittest
from transformers import AutoTokenizer, BatchEncoding, MBartTokenizer
from transformers import AutoTokenizer, BatchEncoding, MBartTokenizer, is_torch_available
from transformers.testing_utils import require_torch
from .test_tokenization_common import TokenizerTesterMixin
from .test_tokenization_xlm_roberta import SAMPLE_VOCAB, SPIECE_UNDERLINE
if is_torch_available():
from transformers.modeling_bart import shift_tokens_right
EN_CODE = 250004
RO_CODE = 250020
@@ -123,35 +126,6 @@ class MBartEnroIntegrationTest(unittest.TestCase):
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_EN"], 250004)
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ro_RO"], 250020)
def test_enro_tokenizer_prepare_seq2seq_batch(self):
batch = self.tokenizer.prepare_seq2seq_batch(
self.src_text,
tgt_texts=self.tgt_text,
max_length=len(self.expected_src_tokens),
)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual((2, 14), batch.input_ids.shape)
self.assertEqual((2, 14), batch.attention_mask.shape)
result = batch.input_ids.tolist()[0]
self.assertListEqual(self.expected_src_tokens, result)
self.assertEqual(2, batch.decoder_input_ids[0, -1]) # EOS
# Test that special tokens are reset
self.assertEqual(self.tokenizer.prefix_tokens, [])
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
def test_max_target_length(self):
batch = self.tokenizer.prepare_seq2seq_batch(
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10
)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
# max_target_length will default to max_length if not specified
batch = self.tokenizer.prepare_seq2seq_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 3)
def test_enro_tokenizer_batch_encode_plus(self):
ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
self.assertListEqual(self.expected_src_tokens, ids)
@@ -169,7 +143,9 @@ class MBartEnroIntegrationTest(unittest.TestCase):
assert isinstance(src_text[0], str)
desired_max_length = 10
ids = self.tokenizer.prepare_seq2seq_batch(
src_text, return_tensors=None, max_length=desired_max_length
src_text,
return_tensors=None,
max_length=desired_max_length,
).input_ids[0]
self.assertEqual(ids[-2], 2)
self.assertEqual(ids[-1], EN_CODE)
@@ -184,3 +160,53 @@ class MBartEnroIntegrationTest(unittest.TestCase):
self.tokenizer.save_pretrained(tmpdirname)
new_tok = MBartTokenizer.from_pretrained(tmpdirname)
self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)
# prepare_seq2seq_batch tests below
@require_torch
def test_batch_fairseq_parity(self):
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
self.src_text, tgt_texts=self.tgt_text, return_tensors="pt"
)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
for k in batch:
batch[k] = batch[k].tolist()
# batch = {k: v.tolist() for k,v in batch.items()}
# fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
# batch.decoder_inputs_ids[0][0] ==
assert batch.input_ids[1][-2:] == [2, EN_CODE]
assert batch.decoder_input_ids[1][0] == RO_CODE
assert batch.decoder_input_ids[1][-1] == 2
assert batch.labels[1][-2:] == [2, RO_CODE]
@require_torch
def test_enro_tokenizer_prepare_seq2seq_batch(self):
batch = self.tokenizer.prepare_seq2seq_batch(
self.src_text,
tgt_texts=self.tgt_text,
max_length=len(self.expected_src_tokens),
)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual((2, 14), batch.input_ids.shape)
self.assertEqual((2, 14), batch.attention_mask.shape)
result = batch.input_ids.tolist()[0]
self.assertListEqual(self.expected_src_tokens, result)
self.assertEqual(2, batch.decoder_input_ids[0, -1]) # EOS
# Test that special tokens are reset
self.assertEqual(self.tokenizer.prefix_tokens, [])
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
def test_seq2seq_max_target_length(self):
batch = self.tokenizer.prepare_seq2seq_batch(
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10
)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
# max_target_length will default to max_length if not specified
batch = self.tokenizer.prepare_seq2seq_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 3)

View File

@@ -63,7 +63,6 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
batch = self.pegasus_large_tokenizer.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts, max_target_length=5)
assert batch.input_ids.shape == (2, 1024)
assert batch.attention_mask.shape == (2, 1024)
assert "decoder_input_ids" in batch # because tgt_texts was specified
assert batch.decoder_input_ids.shape == (2, 5)
assert batch.decoder_attention_mask.shape == (2, 5)
assert len(batch) == 4 # no extra keys
assert "labels" in batch # because tgt_texts was specified
assert batch.labels.shape == (2, 5)
assert len(batch) == 3 # input_ids, attention_mask, labels. Other things make by BartModel

View File

@@ -66,7 +66,7 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def get_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
return RobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs)
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
def get_rust_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
@@ -78,7 +78,7 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
return input_text, output_text
def test_full_tokenizer(self):
tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
tokenizer = self.tokenizer_class(self.vocab_file, self.merges_file, **self.special_tokens_map)
text = "lower newer"
bpe_tokens = ["l", "o", "w", "er", "\u0120", "n", "e", "w", "er"]
tokens = tokenizer.tokenize(text) # , add_prefix_space=True)
@@ -99,7 +99,7 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
@slow
def test_sequence_builders(self):
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
tokenizer = self.tokenizer_class.from_pretrained("roberta-base")
text = tokenizer.encode("sequence builders", add_special_tokens=False)
text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
@@ -137,7 +137,7 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
first_char = tokenizer.convert_ids_to_tokens(encoded[1])[0]
self.assertNotEqual(first_char, space_encoding)
# Testing spaces after special tokenss
# Testing spaces after special tokens
mask = "<mask>"
tokenizer.add_special_tokens(
{"mask_token": AddedToken(mask, lstrip=True, rstrip=False)}

View File

@@ -153,7 +153,7 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def test_max_target_length(self):
tokenizer = self.t5_base_tokenizer
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
src_text = ["A short paragraph for summrization.", "Another short paragraph for summrization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
@@ -161,14 +161,14 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors=FRAMEWORK
)
self.assertEqual(32, batch["decoder_input_ids"].shape[1])
self.assertEqual(32, batch["labels"].shape[1])
self.assertEqual(32, batch["decoder_attention_mask"].shape[1])
# test None max_target_length
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors=FRAMEWORK
)
self.assertEqual(32, batch["decoder_input_ids"].shape[1])
self.assertEqual(32, batch["labels"].shape[1])
self.assertEqual(32, batch["decoder_attention_mask"].shape[1])
def test_outputs_not_longer_than_maxlen(self):
@@ -190,7 +190,7 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors=FRAMEWORK)
src_ids = list(batch.input_ids.numpy()[0])
tgt_ids = list(batch.decoder_input_ids.numpy()[0])
tgt_ids = list(batch.labels.numpy()[0])
self.assertEqual(expected_src_tokens, src_ids)
self.assertEqual(expected_tgt_tokens, tgt_ids)