From 59b7334c87a395e34900a875d802b32c5d126045 Mon Sep 17 00:00:00 2001
From: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Date: Fri, 14 Oct 2022 16:10:36 +0200
Subject: [PATCH] Fix `test_tf_encode_plus_sent_to_model` for `TAPAS` (#19559)

Co-authored-by: ydshieh
---
Notes (review commentary only; this text sits after the three-dash line, so
it is neither part of the commit message nor part of the patch itself):

  * The hunk replaces the one-line delegation to
    `super().test_tf_encode_plus_sent_to_model()` with an inline test body
    and additionally decorates the test with `@slow`.
  * The inline body obtains a table from `self.get_table(tokenizer, length=0)`
    and passes it as the first positional argument to both `encode_plus` and
    `batch_encode_plus`, ahead of the text. Presumably this table argument is
    why the inherited test could not be reused here; confirm against the mixin.
  * Both `return` statements sit inside the `for tokenizer in tokenizers` loop,
    so they end the whole test method, not only the current sub-test.
    NOTE(review): if `get_tokenizers` can yield more than one tokenizer, the
    remaining ones are never exercised; `continue` may be what is intended.
  * `slow` and `merge_model_tokenizer_mappings` are used but not imported in
    this hunk; presumably they are already imported at the top of the test
    module. TODO confirm.

 tests/models/tapas/test_tokenization_tapas.py | 33 ++++++++++++++++++-
 1 file changed, 32 insertions(+), 1 deletion(-)

diff --git a/tests/models/tapas/test_tokenization_tapas.py b/tests/models/tapas/test_tokenization_tapas.py
index f712f324f9..ff873c76cd 100644
--- a/tests/models/tapas/test_tokenization_tapas.py
+++ b/tests/models/tapas/test_tokenization_tapas.py
@@ -143,8 +143,39 @@ class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         return input_text, output_text
 
     @require_tensorflow_probability
+    @slow
     def test_tf_encode_plus_sent_to_model(self):
-        super().test_tf_encode_plus_sent_to_model()
+        from transformers import TF_MODEL_MAPPING, TOKENIZER_MAPPING
+
+        MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(TF_MODEL_MAPPING, TOKENIZER_MAPPING)
+
+        tokenizers = self.get_tokenizers(do_lower_case=False)
+        for tokenizer in tokenizers:
+            with self.subTest(f"{tokenizer.__class__.__name__}"):
+                if tokenizer.__class__ not in MODEL_TOKENIZER_MAPPING:
+                    return
+
+                config_class, model_class = MODEL_TOKENIZER_MAPPING[tokenizer.__class__]
+                config = config_class()
+
+                if config.is_encoder_decoder or config.pad_token_id is None:
+                    return
+
+                model = model_class(config)
+
+                # Make sure the model contains at least the full vocabulary size in its embedding matrix
+                self.assertGreaterEqual(model.config.vocab_size, len(tokenizer))
+
+                # Build sequence
+                first_ten_tokens = list(tokenizer.get_vocab().keys())[:10]
+                sequence = " ".join(first_ten_tokens)
+                table = self.get_table(tokenizer, length=0)
+                encoded_sequence = tokenizer.encode_plus(table, sequence, return_tensors="tf")
+                batch_encoded_sequence = tokenizer.batch_encode_plus(table, [sequence, sequence], return_tensors="tf")
+
+                # This should not fail
+                model(encoded_sequence)
+                model(batch_encoded_sequence)
 
     def test_rust_and_python_full_tokenizers(self):
         if not self.test_rust_tokenizer: