tf add resize_token_embeddings method (#4351)
* resize token embeddings * add tokens * add tokens * add tokens * add t5 token method * add t5 token method * add t5 token method * typo * debugging input * debugging input * debug * debug * debug * trying to set embedding tokens properly * set embeddings for generation head too * set embeddings for generation head too * debugging * debugging * enable generation * add base method * add base method * add base method * return logits in the main call * reverting to generation * revert back * set embeddings for the bert main layer * description * fix conflicts * logging * set base model as self * refactor * tf_bert add method * tf_bert add method * tf_bert add method * tf_bert add method * tf_bert add method * tf_bert add method * tf_bert add method * tf_bert add method * v0 * v0 * finalize * final * black * add tests * revert back the emb call * comments * comments * add the second test * add vocab size condig * add tf models * add tf models. add common tests * remove model specific embedding tests * stylish * remove files * stylez * Update src/transformers/modeling_tf_transfo_xl.py change the error. Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * adding unchanged weight test Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -472,6 +472,30 @@ class TFModelTesterMixin:
|
||||
|
||||
model(inputs)
|
||||
|
||||
def test_resize_token_embeddings(self):
    """Check that resized input embeddings have the requested size and keep their weights.

    For every class in ``self.all_model_classes`` the embeddings are resized to
    ``vocab_size - 10``, ``vocab_size + 10`` and ``None``. With ``None`` the
    first dimension is expected to stay at ``config.vocab_size``. The test is
    skipped when ``self.test_resize_embeddings`` is falsy.
    """
    if not self.test_resize_embeddings:
        return

    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    input_shape = [1, 10, config.hidden_size]

    for model_class in self.all_model_classes:
        candidate_sizes = [config.vocab_size - 10, config.vocab_size + 10, None]
        for size in candidate_sizes:
            # A fresh model per size, with its input embedding layer built.
            model = model_class(config=config)
            old_embeddings = model.get_input_embeddings()
            old_embeddings.build(input_shape)

            # Resize the embeddings to the requested size.
            resized = model._get_resized_embeddings(old_embeddings, size)

            # The first dimension of the resized embeddings must match the request.
            expected_rows = config.vocab_size if size is None else size
            self.assertEqual(resized.shape[0], expected_rows)

            # Rows present in both the old and the resized embeddings must be
            # unchanged; zip() stops at the shorter of the two.
            old_weights = model._get_word_embeddings(old_embeddings)
            row_pairs = zip(old_weights.numpy(), resized.numpy())
            models_equal = not any(
                np.sum(abs(old_row - new_row)) > 0 for old_row, new_row in row_pairs
            )
            self.assertTrue(models_equal)
|
||||
|
||||
def test_lm_head_model_random_no_beam_search_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
|
||||
|
||||
@@ -169,7 +169,6 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
)
|
||||
test_pruning = True
|
||||
test_torchscript = True
|
||||
test_resize_embeddings = True
|
||||
test_head_masking = True
|
||||
|
||||
def setUp(self):
|
||||
|
||||
Reference in New Issue
Block a user