Fix test_tf_encode_plus_sent_to_model for LayoutLMv3 (#18898)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -31,7 +31,14 @@ from transformers import (
|
|||||||
logging,
|
logging,
|
||||||
)
|
)
|
||||||
from transformers.models.layoutlmv3.tokenization_layoutlmv3 import VOCAB_FILES_NAMES, LayoutLMv3Tokenizer
|
from transformers.models.layoutlmv3.tokenization_layoutlmv3 import VOCAB_FILES_NAMES, LayoutLMv3Tokenizer
|
||||||
from transformers.testing_utils import is_pt_tf_cross_test, require_pandas, require_tokenizers, require_torch, slow
|
from transformers.testing_utils import (
|
||||||
|
is_pt_tf_cross_test,
|
||||||
|
require_pandas,
|
||||||
|
require_tf,
|
||||||
|
require_tokenizers,
|
||||||
|
require_torch,
|
||||||
|
slow,
|
||||||
|
)
|
||||||
|
|
||||||
from ...test_tokenization_common import SMALL_TRAINING_CORPUS, TokenizerTesterMixin, merge_model_tokenizer_mappings
|
from ...test_tokenization_common import SMALL_TRAINING_CORPUS, TokenizerTesterMixin, merge_model_tokenizer_mappings
|
||||||
|
|
||||||
@@ -2400,3 +2407,39 @@ class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
@unittest.skip("Doesn't support another framework than PyTorch")
|
@unittest.skip("Doesn't support another framework than PyTorch")
|
||||||
def test_np_encode_plus_sent_to_model(self):
|
def test_np_encode_plus_sent_to_model(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
@slow
|
||||||
|
def test_tf_encode_plus_sent_to_model(self):
|
||||||
|
from transformers import TF_MODEL_MAPPING, TOKENIZER_MAPPING
|
||||||
|
|
||||||
|
MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(TF_MODEL_MAPPING, TOKENIZER_MAPPING)
|
||||||
|
|
||||||
|
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||||
|
for tokenizer in tokenizers:
|
||||||
|
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||||
|
if tokenizer.__class__ not in MODEL_TOKENIZER_MAPPING:
|
||||||
|
return
|
||||||
|
|
||||||
|
config_class, model_class = MODEL_TOKENIZER_MAPPING[tokenizer.__class__]
|
||||||
|
config = config_class()
|
||||||
|
|
||||||
|
if config.is_encoder_decoder or config.pad_token_id is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
# Make sure the model contains at least the full vocabulary size in its embedding matrix
|
||||||
|
self.assertGreaterEqual(model.config.vocab_size, len(tokenizer))
|
||||||
|
|
||||||
|
# Build sequence
|
||||||
|
first_ten_tokens = list(tokenizer.get_vocab().keys())[:10]
|
||||||
|
boxes = [[1000, 1000, 1000, 1000] for _ in range(len(first_ten_tokens))]
|
||||||
|
encoded_sequence = tokenizer.encode_plus(first_ten_tokens, boxes=boxes, return_tensors="tf")
|
||||||
|
batch_encoded_sequence = tokenizer.batch_encode_plus(
|
||||||
|
[first_ten_tokens, first_ten_tokens], boxes=[boxes, boxes], return_tensors="tf"
|
||||||
|
)
|
||||||
|
|
||||||
|
# This should not fail
|
||||||
|
model(encoded_sequence)
|
||||||
|
model(batch_encoded_sequence)
|
||||||
|
|||||||
Reference in New Issue
Block a user