Refactor prepare_seq2seq_batch (#9524)
* Add target contextmanager and rework prepare_seq2seq_batch
* Fix tests, treat BART and Barthez
* Add last tokenizers
* Fix test
* Set src token before calling the superclass
* Remove special behavior for T5
* Remove needless imports
* Remove needless asserts
This commit is contained in:
@@ -508,12 +508,6 @@ class TestMarian_en_ROMANCE(MarianIntegrationTest):
|
||||
def test_batch_generation_en_ROMANCE_multi(self):
|
||||
self._assert_generated_batch_equal_expected()
|
||||
|
||||
def test_tokenizer_handles_empty(self):
|
||||
normalized = self.tokenizer.normalize("")
|
||||
self.assertIsInstance(normalized, str)
|
||||
with self.assertRaises(ValueError):
|
||||
self.tokenizer.prepare_seq2seq_batch([""], return_tensors="pt")
|
||||
|
||||
@slow
|
||||
def test_pipeline(self):
|
||||
device = 0 if torch_device == "cuda" else -1
|
||||
|
||||
@@ -83,6 +83,7 @@ class TokenizerTesterMixin:
|
||||
from_pretrained_kwargs = None
|
||||
from_pretrained_filter = None
|
||||
from_pretrained_vocab_key = "vocab_file"
|
||||
test_seq2seq = True
|
||||
|
||||
def setUp(self) -> None:
|
||||
# Tokenizer.filter makes it possible to filter which Tokenizer to test based on all the
|
||||
@@ -1799,10 +1800,11 @@ class TokenizerTesterMixin:
|
||||
|
||||
@require_torch
|
||||
def test_prepare_seq2seq_batch(self):
|
||||
if not self.test_seq2seq:
|
||||
return
|
||||
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
if not hasattr(tokenizer, "prepare_seq2seq_batch"):
|
||||
return
|
||||
# Longer text that will definitely require truncation.
|
||||
src_text = [
|
||||
" UN Chief Says There Is No Military Solution in Syria",
|
||||
|
||||
@@ -26,6 +26,7 @@ class CTRLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
tokenizer_class = CTRLTokenizer
|
||||
test_rust_tokenizer = False
|
||||
test_seq2seq = False
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
@@ -32,6 +32,7 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
rust_tokenizer_class = GPT2TokenizerFast
|
||||
test_rust_tokenizer = True
|
||||
from_pretrained_kwargs = {"add_prefix_space": True}
|
||||
test_seq2seq = False
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
@@ -31,6 +31,7 @@ class OpenAIGPTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer_class = OpenAIGPTTokenizer
|
||||
rust_tokenizer_class = OpenAIGPTTokenizerFast
|
||||
test_rust_tokenizer = True
|
||||
test_seq2seq = False
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
@@ -33,6 +33,7 @@ class ReformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer_class = ReformerTokenizer
|
||||
rust_tokenizer_class = ReformerTokenizerFast
|
||||
test_rust_tokenizer = True
|
||||
test_seq2seq = False
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
@@ -44,6 +44,7 @@ class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
test_rust_tokenizer = False
|
||||
space_between_special_tokens = True
|
||||
from_pretrained_filter = filter_non_english
|
||||
test_seq2seq = False
|
||||
|
||||
def get_table(
|
||||
self,
|
||||
|
||||
@@ -26,6 +26,7 @@ class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
tokenizer_class = TransfoXLTokenizer
|
||||
test_rust_tokenizer = False
|
||||
test_seq2seq = False
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
Reference in New Issue
Block a user