Refactor prepare_seq2seq_batch (#9524)

* Add target contextmanager and rework prepare_seq2seq_batch

* Fix tests, treat BART and Barthez

* Add last tokenizers

* Fix test

* Set src token before calling the superclass

* Remove special behavior for T5

* Remove needless imports

* Remove needless asserts
This commit is contained in:
Sylvain Gugger
2021-01-12 18:19:38 -05:00
committed by GitHub
parent e6ecef711e
commit 063d8d27f4
24 changed files with 169 additions and 700 deletions

View File

@@ -83,6 +83,7 @@ class TokenizerTesterMixin:
from_pretrained_kwargs = None
from_pretrained_filter = None
from_pretrained_vocab_key = "vocab_file"
test_seq2seq = True
def setUp(self) -> None:
# Tokenizer.filter makes it possible to filter which Tokenizer to test based on all the
@@ -1799,10 +1800,11 @@ class TokenizerTesterMixin:
@require_torch
def test_prepare_seq2seq_batch(self):
if not self.test_seq2seq:
return
tokenizer = self.get_tokenizer()
if not hasattr(tokenizer, "prepare_seq2seq_batch"):
return
# Longer text that will definitely require truncation.
src_text = [
" UN Chief Says There Is No Military Solution in Syria",