fix retribert's test_torch_encode_plus_sent_to_model (#17231)
This commit is contained in:
@@ -27,9 +27,9 @@ from transformers.models.bert.tokenization_bert import (
|
|||||||
_is_punctuation,
|
_is_punctuation,
|
||||||
_is_whitespace,
|
_is_whitespace,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import require_tokenizers, slow
|
from transformers.testing_utils import require_tokenizers, require_torch, slow
|
||||||
|
|
||||||
from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english
|
from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english, merge_model_tokenizer_mappings
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.tests.bert.test_modeling_bert.py with Bert->RetriBert
|
# Copied from transformers.tests.bert.test_modeling_bert.py with Bert->RetriBert
|
||||||
@@ -338,3 +338,47 @@ class RetriBertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
]
|
]
|
||||||
self.assertListEqual(tokens_without_spe_char_p, expected_tokens)
|
self.assertListEqual(tokens_without_spe_char_p, expected_tokens)
|
||||||
self.assertListEqual(tokens_without_spe_char_r, expected_tokens)
|
self.assertListEqual(tokens_without_spe_char_r, expected_tokens)
|
||||||
|
|
||||||
|
# RetriBertModel doesn't define `get_input_embeddings` and it's forward method doesn't take only the output of the tokenizer as input
|
||||||
|
@require_torch
|
||||||
|
@slow
|
||||||
|
def test_torch_encode_plus_sent_to_model(self):
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import MODEL_MAPPING, TOKENIZER_MAPPING
|
||||||
|
|
||||||
|
MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(MODEL_MAPPING, TOKENIZER_MAPPING)
|
||||||
|
|
||||||
|
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||||
|
for tokenizer in tokenizers:
|
||||||
|
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||||
|
|
||||||
|
if tokenizer.__class__ not in MODEL_TOKENIZER_MAPPING:
|
||||||
|
return
|
||||||
|
|
||||||
|
config_class, model_class = MODEL_TOKENIZER_MAPPING[tokenizer.__class__]
|
||||||
|
config = config_class()
|
||||||
|
|
||||||
|
if config.is_encoder_decoder or config.pad_token_id is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
# The following test is different from the common's one
|
||||||
|
self.assertGreaterEqual(model.bert_query.get_input_embeddings().weight.shape[0], len(tokenizer))
|
||||||
|
|
||||||
|
# Build sequence
|
||||||
|
first_ten_tokens = list(tokenizer.get_vocab().keys())[:10]
|
||||||
|
sequence = " ".join(first_ten_tokens)
|
||||||
|
encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="pt")
|
||||||
|
|
||||||
|
# Ensure that the BatchEncoding.to() method works.
|
||||||
|
encoded_sequence.to(model.device)
|
||||||
|
|
||||||
|
batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt")
|
||||||
|
# This should not fail
|
||||||
|
|
||||||
|
with torch.no_grad(): # saves some time
|
||||||
|
# The following lines are different from the common's ones
|
||||||
|
model.embed_questions(**encoded_sequence)
|
||||||
|
model.embed_questions(**batch_encoded_sequence)
|
||||||
|
|||||||
Reference in New Issue
Block a user