Refactor prepare_seq2seq_batch (#9524)

* Add target contextmanager and rework prepare_seq2seq_batch

* Fix tests, treat BART and Barthez

* Add last tokenizers

* Fix test

* Set src token before calling the superclass

* Remove special behavior for T5

* Remove needless imports

* Remove needless asserts
This commit is contained in:
Sylvain Gugger
2021-01-12 18:19:38 -05:00
committed by GitHub
parent e6ecef711e
commit 063d8d27f4
24 changed files with 169 additions and 700 deletions

View File

@@ -508,12 +508,6 @@ class TestMarian_en_ROMANCE(MarianIntegrationTest):
def test_batch_generation_en_ROMANCE_multi(self):
self._assert_generated_batch_equal_expected()
def test_tokenizer_handles_empty(self):
normalized = self.tokenizer.normalize("")
self.assertIsInstance(normalized, str)
with self.assertRaises(ValueError):
self.tokenizer.prepare_seq2seq_batch([""], return_tensors="pt")
@slow
def test_pipeline(self):
device = 0 if torch_device == "cuda" else -1

View File

@@ -83,6 +83,7 @@ class TokenizerTesterMixin:
from_pretrained_kwargs = None
from_pretrained_filter = None
from_pretrained_vocab_key = "vocab_file"
test_seq2seq = True
def setUp(self) -> None:
# Tokenizer.filter makes it possible to filter which Tokenizer to case based on all the
@@ -1799,10 +1800,11 @@ class TokenizerTesterMixin:
@require_torch
def test_prepare_seq2seq_batch(self):
if not self.test_seq2seq:
return
tokenizer = self.get_tokenizer()
if not hasattr(tokenizer, "prepare_seq2seq_batch"):
return
# Longer text that will definitely require truncation.
src_text = [
" UN Chief Says There Is No Military Solution in Syria",

View File

@@ -26,6 +26,7 @@ class CTRLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = CTRLTokenizer
test_rust_tokenizer = False
test_seq2seq = False
def setUp(self):
super().setUp()

View File

@@ -32,6 +32,7 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
rust_tokenizer_class = GPT2TokenizerFast
test_rust_tokenizer = True
from_pretrained_kwargs = {"add_prefix_space": True}
test_seq2seq = False
def setUp(self):
super().setUp()

View File

@@ -31,6 +31,7 @@ class OpenAIGPTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = OpenAIGPTTokenizer
rust_tokenizer_class = OpenAIGPTTokenizerFast
test_rust_tokenizer = True
test_seq2seq = False
def setUp(self):
super().setUp()

View File

@@ -33,6 +33,7 @@ class ReformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = ReformerTokenizer
rust_tokenizer_class = ReformerTokenizerFast
test_rust_tokenizer = True
test_seq2seq = False
def setUp(self):
super().setUp()

View File

@@ -44,6 +44,7 @@ class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
test_rust_tokenizer = False
space_between_special_tokens = True
from_pretrained_filter = filter_non_english
test_seq2seq = False
def get_table(
self,

View File

@@ -26,6 +26,7 @@ class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = TransfoXLTokenizer
test_rust_tokenizer = False
test_seq2seq = False
def setUp(self):
super().setUp()