TF MT5 embeddings resize (#15567)

* Fix TF MT5 vocab resize

* more assertive testing
This commit is contained in:
Joao Gante
2022-02-11 17:35:10 +00:00
committed by GitHub
parent 8c03df1010
commit 2f40c728c9
3 changed files with 37 additions and 1 deletions

View File

@@ -22,7 +22,24 @@ from transformers.testing_utils import require_sentencepiece, require_tf, requir
if is_tf_available():
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM
from transformers import AutoTokenizer, T5Tokenizer, TFAutoModelForSeq2SeqLM, TFMT5ForConditionalGeneration
@require_tf
class TFMT5ModelTest(unittest.TestCase): # no mixin with common tests -> most cases are already covered in the TF T5
@slow
def test_resize_embeddings(self):
model = TFMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
original_vocab_size = model.get_input_embeddings().weight.shape[0]
# the vocab size is defined in the model config
self.assertEqual(original_vocab_size, model.config.vocab_size)
tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
tokenizer.add_special_tokens({"bos_token": "", "eos_token": ""})
model._resize_token_embeddings(len(tokenizer))
# the vocab size is now resized to the length of the tokenizer, which is different from the original size
self.assertEqual(model.get_input_embeddings().weight.shape[0], len(tokenizer))
self.assertNotEqual(model.get_input_embeddings().weight.shape[0], original_vocab_size)
@require_tf

View File

@@ -314,6 +314,20 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO: Fix head-masking according to PyTorch T5 model
pass
@slow
def test_resize_embeddings(self):
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
original_vocab_size = model.get_input_embeddings().weight.shape[0]
# the vocab size is defined in the model config
self.assertEqual(original_vocab_size, model.config.vocab_size)
tokenizer = T5Tokenizer.from_pretrained("t5-small")
tokenizer.add_special_tokens({"bos_token": "", "eos_token": ""})
model._resize_token_embeddings(len(tokenizer))
# the vocab size is now resized to the length of the tokenizer, which is different from the original size
self.assertEqual(model.get_input_embeddings().weight.shape[0], len(tokenizer))
self.assertNotEqual(model.get_input_embeddings().weight.shape[0], original_vocab_size)
class TFT5EncoderOnlyModelTester:
def __init__(