[Pretrained Model] Add resize_position_embeddings (#13559)
* finish * delete bogus file * correct some stuff * finish * finish
This commit is contained in:
committed by
GitHub
parent
c783e14887
commit
95f933ea85
@@ -94,6 +94,7 @@ class ModelTesterMixin:
|
||||
test_torchscript = True
|
||||
test_pruning = True
|
||||
test_resize_embeddings = True
|
||||
test_resize_position_embeddings = False
|
||||
test_head_masking = True
|
||||
test_missing_keys = True
|
||||
test_model_parallel = False
|
||||
@@ -1067,6 +1068,85 @@ class ModelTesterMixin:
|
||||
hidden_states_with_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
|
||||
self.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
|
||||
|
||||
def test_resize_position_vector_embeddings(self):
    """Check that ``resize_position_embeddings`` grows and shrinks the position embeddings.

    For every class in ``self.all_model_classes`` the position embeddings are resized
    first to ``max_position_embeddings + 10`` and then to ``max_position_embeddings - 5``.
    After each resize the test verifies that:

    * ``model.config.max_position_embeddings`` reports the new size,
    * every position embedding matrix has the new number of rows,
    * a forward pass still runs.

    Finally it checks that the rows which survived both resizes are unchanged.

    The test does nothing unless the test class sets
    ``test_resize_position_embeddings = True``.
    """
    if not self.test_resize_position_embeddings:
        return

    original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    def collect_position_embeddings(model):
        # Encoder-decoder models return an (encoder, decoder) pair, other models a
        # single module. Normalize both cases to a tuple so the checks below are
        # written only once.
        embeddings = model.get_position_embeddings()
        if model.config.is_encoder_decoder:
            encoder_embed, decoder_embed = embeddings
            return (encoder_embed, decoder_embed)
        return (embeddings,)

    for model_class in self.all_model_classes:
        config = copy.deepcopy(original_config)
        model = model_class(config)
        model.to(torch_device)

        if self.model_tester.is_training is False:
            model.eval()

        max_position_embeddings = config.max_position_embeddings

        # Retrieve the embeddings and clone them before any resize happens.
        cloned_weights = [embed.weight.clone() for embed in collect_position_embeddings(model)]
        resized_embeddings = ()

        # First grow the position embeddings by 10, then shrink them to 5 below the
        # original size.
        for delta in (10, -5):
            new_size = max_position_embeddings + delta
            model.resize_position_embeddings(new_size)
            self.assertEqual(model.config.max_position_embeddings, new_size)

            # Check that it actually resizes the embeddings matrix
            resized_embeddings = collect_position_embeddings(model)
            for embed, cloned in zip(resized_embeddings, cloned_weights):
                self.assertEqual(embed.weight.shape[0], cloned.shape[0] + delta)

            # Check that the model can still do a forward pass successfully
            # (every parameter should be resized)
            model(**self._prepare_for_class(inputs_dict, model_class))

        # Check that growing and shrinking the position embeddings has not modified
        # the rows that are still present in the embedding matrix.
        for index, (embed, cloned) in enumerate(zip(resized_embeddings, cloned_weights)):
            kept_rows = min(cloned.shape[0], embed.weight.shape[0])
            self.assertTrue(
                torch.equal(cloned[:kept_rows], embed.weight[:kept_rows]),
                f"{model_class.__name__}: position embedding #{index} changed in its retained rows",
            )
|
||||
|
||||
def test_resize_tokens_embeddings(self):
|
||||
(
|
||||
original_config,
|
||||
|
||||
@@ -214,6 +214,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
test_torchscript = True
|
||||
test_resize_embeddings = True
|
||||
test_sequence_classification_problem_types = True
|
||||
test_resize_position_embeddings = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = DistilBertModelTester(self)
|
||||
|
||||
@@ -229,6 +229,7 @@ class PegasusModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
||||
all_model_classes = (PegasusModel, PegasusForConditionalGeneration) if is_torch_available() else ()
|
||||
all_generative_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_resize_position_embeddings = True
|
||||
test_pruning = False
|
||||
test_missing_keys = False
|
||||
|
||||
@@ -526,6 +527,7 @@ class PegasusStandaloneDecoderModelTester:
|
||||
class PegasusStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (PegasusDecoder, PegasusForCausalLM) if is_torch_available() else ()
|
||||
all_generative_model_classes = (PegasusForCausalLM,) if is_torch_available() else ()
|
||||
test_resize_position_embeddings = True
|
||||
test_pruning = False
|
||||
is_encoder_decoder = False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user