Encoder-decoder models: move embedding scale to nn.Module (#30410)

* move scaling to nn.Module

* let the test be here for now (need to fix)

* failing tests

* last failing models

* Revert commit 4c14817f38

* clean-up

* oops forgot

* codestyle

* raise NotImplemented when possible

* Update tests/test_modeling_common.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* skip tests in respective modeling files

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Raushan Turganbay
2024-05-01 12:33:00 +05:00
committed by GitHub
parent 9d31b32e9d
commit 38a4bf79ad
36 changed files with 541 additions and 80 deletions

View File

@@ -167,6 +167,10 @@ class AlignVisionModelTest(ModelTesterMixin, unittest.TestCase):
def test_inputs_embeds(self):
pass
# Skip stub: disables the inputs_embeds-vs-input_ids equivalence test (presumably inherited
# from ModelTesterMixin) for the vision model; the empty body never runs because of the decorator.
@unittest.skip(reason="AlignVisionModel does not use inputs_embeds")
def test_inputs_embeds_matches_input_ids(self):
pass
@unittest.skip(reason="AlignVisionModel does not support input and output embeddings")
def test_model_common_attributes(self):
pass
@@ -379,6 +383,10 @@ class AlignTextModelTest(ModelTesterMixin, unittest.TestCase):
def test_inputs_embeds(self):
pass
# Skip stub: opts the Align text model tests out of the inputs_embeds-vs-input_ids
# equivalence test (presumably inherited from ModelTesterMixin).
@unittest.skip(reason="Align does not use inputs_embeds")
def test_inputs_embeds_matches_input_ids(self):
pass
@unittest.skip(reason="AlignTextModel has no base class and is not available in MODEL_MAPPING")
def test_save_load_fast_init_from_base(self):
pass
@@ -473,6 +481,10 @@ class AlignModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_inputs_embeds(self):
pass
# Skip stub: opts the combined Align model tests out of the inputs_embeds-vs-input_ids
# equivalence test (presumably inherited from ModelTesterMixin).
@unittest.skip(reason="Align does not use inputs_embeds")
def test_inputs_embeds_matches_input_ids(self):
pass
@unittest.skip(reason="Retain_grad is tested in individual model tests")
def test_retain_grad_hidden_states_attentions(self):
pass

View File

@@ -579,6 +579,29 @@ class BarkSemanticModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Te
with torch.no_grad():
model(**inputs)[0]
# override as the input arg is called "input_embeds", not "inputs_embeds"
def test_inputs_embeds_matches_input_ids(self):
    """Check that passing embeddings gives the same first output as passing token ids.

    For every model class, the model is run once with the prepared inputs, then
    again with `input_ids` replaced by the result of applying the model's input
    embedding layer to those ids (under the `input_embeds` keyword). The first
    element of both outputs must be close according to `torch.allclose`.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        model = model_class(config)
        model.to(torch_device)
        model.eval()

        # Work on a deep copy so that swapping keys below cannot affect `inputs_dict`.
        kwargs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
        with torch.no_grad():
            ids_output = model(**kwargs)[0]

        # Replace the token ids by their embeddings, computed with the model's own layer.
        token_ids = kwargs.pop("input_ids")
        embedding_layer = model.get_input_embeddings()
        kwargs["input_embeds"] = embedding_layer(token_ids)
        with torch.no_grad():
            embeds_output = model(**kwargs)[0]

        self.assertTrue(torch.allclose(embeds_output, ids_output))
@require_torch_fp16
def test_generate_fp16(self):
config, input_dict = self.model_tester.prepare_config_and_inputs()
@@ -645,6 +668,29 @@ class BarkCoarseModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
with torch.no_grad():
model(**inputs)[0]
# override as the input arg is called "input_embeds", not "inputs_embeds"
def test_inputs_embeds_matches_input_ids(self):
    """Verify that the ids path and the embeddings path produce matching outputs.

    Each model class is instantiated from the common config and evaluated twice:
    first with `input_ids`, then with `input_embeds` obtained by feeding the same
    ids through `get_input_embeddings()`. The first output of each run is compared
    with `torch.allclose`.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        model = model_class(config)
        model.to(torch_device)
        model.eval()

        # Deep copy: the dict is mutated below and must not leak into `inputs_dict`.
        call_kwargs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
        with torch.no_grad():
            reference = model(**call_kwargs)[0]

        # Same call, but with embeddings instead of ids.
        ids = call_kwargs.pop("input_ids")
        embed = model.get_input_embeddings()
        call_kwargs["input_embeds"] = embed(ids)
        with torch.no_grad():
            from_embeds = model(**call_kwargs)[0]

        self.assertTrue(torch.allclose(from_embeds, reference))
@require_torch_fp16
def test_generate_fp16(self):
config, input_dict = self.model_tester.prepare_config_and_inputs()
@@ -709,6 +755,10 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
with torch.no_grad():
model(**inputs)[0]
# Skip stub: unlike the semantic and coarse Bark test classes, which override this test,
# the fine model test class disables it entirely (reason given in the decorator).
@unittest.skip("FineModel relies on codebook idx and does not return same logits")
def test_inputs_embeds_matches_input_ids(self):
pass
@require_torch_fp16
def test_generate_fp16(self):
config, input_dict = self.model_tester.prepare_config_and_inputs()

View File

@@ -506,6 +506,10 @@ class BridgeTowerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
def test_inputs_embeds(self):
pass
# Skip stub: disables the inputs_embeds-vs-input_ids equivalence test (presumably inherited
# from ModelTesterMixin) for BridgeTower; the body is a deliberate no-op.
@unittest.skip(reason="Bridge Tower does not use inputs_embeds")
def test_inputs_embeds_matches_input_ids(self):
pass
# We will verify our results on an image of cute cats
def prepare_img():

View File

@@ -502,6 +502,10 @@ class CanineModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
# CANINE does not use inputs_embeds
pass
# Skip stub: disables the inputs_embeds-vs-input_ids equivalence test for CANINE.
# The reason string names the model under test (it previously read "Canine Tower",
# a leftover from the Bridge Tower stub).
@unittest.skip(reason="CANINE does not use inputs_embeds")
def test_inputs_embeds_matches_input_ids(self):
    pass
@unittest.skip("CANINE does not have a get_input_embeddings() method.")
def test_model_common_attributes(self):
pass

View File

@@ -247,6 +247,10 @@ class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
def test_inputs_embeds(self):
pass
# Skip stub: opts Conditional DETR out of the inputs_embeds-vs-input_ids equivalence test
# (presumably inherited from ModelTesterMixin); the body is a deliberate no-op.
@unittest.skip(reason="Conditional DETR does not use inputs_embeds")
def test_inputs_embeds_matches_input_ids(self):
pass
@unittest.skip(reason="Conditional DETR does not have a get_input_embeddings method")
def test_model_common_attributes(self):
pass

View File

@@ -253,6 +253,10 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
def test_inputs_embeds(self):
pass
# Skip stub: opts Deformable DETR out of the inputs_embeds-vs-input_ids equivalence test
# (presumably inherited from ModelTesterMixin); the body is a deliberate no-op.
@unittest.skip(reason="Deformable DETR does not use inputs_embeds")
def test_inputs_embeds_matches_input_ids(self):
pass
@unittest.skip(reason="Deformable DETR does not have a get_input_embeddings method")
def test_model_common_attributes(self):
pass

View File

@@ -303,6 +303,10 @@ class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
def test_inputs_embeds(self):
pass
# Skip stub: opts DETA out of the inputs_embeds-vs-input_ids equivalence test
# (presumably inherited from ModelTesterMixin); the body is a deliberate no-op.
@unittest.skip(reason="DETA does not use inputs_embeds")
def test_inputs_embeds_matches_input_ids(self):
pass
@unittest.skip(reason="DETA does not have a get_input_embeddings method")
def test_model_common_attributes(self):
pass

View File

@@ -247,6 +247,10 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
def test_inputs_embeds(self):
pass
# Skip stub: opts DETR out of the inputs_embeds-vs-input_ids equivalence test
# (presumably inherited from ModelTesterMixin); the body is a deliberate no-op.
@unittest.skip(reason="DETR does not use inputs_embeds")
def test_inputs_embeds_matches_input_ids(self):
pass
@unittest.skip(reason="DETR does not have a get_input_embeddings method")
def test_model_common_attributes(self):
pass

View File

@@ -321,6 +321,10 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
def test_inputs_embeds(self):
pass
# Skip stub: the equivalence test needs to call the model without input ids, which the
# skip reason says FSMT does not allow; the body is a deliberate no-op.
@unittest.skip("Input ids is required for FSMT.")
def test_inputs_embeds_matches_input_ids(self):
pass
@unittest.skip("model weights aren't tied in FSMT.")
def test_tie_model_weights(self):
pass

View File

@@ -182,6 +182,14 @@ class GPTSanJapaneseTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
def test_model_parallelism(self):
super().test_model_parallelism()
# Skip stub: disables the test_inputs_embeds check (presumably inherited from
# ModelTesterMixin) for the GPTSAN base model tests.
@unittest.skip(reason="Gptsan does not use inputs_embeds")
def test_inputs_embeds(self):
pass
# Skip stub: disables the inputs_embeds-vs-input_ids equivalence test (presumably inherited
# from ModelTesterMixin) for the GPTSAN base model tests.
@unittest.skip(reason="Gptsan does not use inputs_embeds")
def test_inputs_embeds_matches_input_ids(self):
pass
@require_torch
class GPTSanJapaneseForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
@@ -212,6 +220,14 @@ class GPTSanJapaneseForConditionalGenerationTest(ModelTesterMixin, GenerationTes
def test_model_parallelism(self):
super().test_model_parallelism()
# Skip stub: disables the test_inputs_embeds check (presumably inherited from
# ModelTesterMixin) for the GPTSAN conditional-generation tests.
@unittest.skip(reason="Gptsan does not use inputs_embeds")
def test_inputs_embeds(self):
pass
# Skip stub: disables the inputs_embeds-vs-input_ids equivalence test (presumably inherited
# from ModelTesterMixin) for the GPTSAN conditional-generation tests.
@unittest.skip(reason="Gptsan does not use inputs_embeds")
def test_inputs_embeds_matches_input_ids(self):
pass
@slow
def test_logits(self):
model = GPTSanJapaneseForConditionalGeneration.from_pretrained("Tanrei/GPTSAN-japanese")

View File

@@ -382,6 +382,10 @@ class IBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
with torch.no_grad():
model(**inputs)[0]
# Skip stub: per the skip reason, I-BERT handles scaling differently when inputs_embeds
# is given, so the ids path and the embeddings path are not compared here.
@unittest.skip("ibert overrides scaling to None if inputs_embeds")
def test_inputs_embeds_matches_input_ids(self):
pass
@require_torch
class IBertModelIntegrationTest(unittest.TestCase):

View File

@@ -180,6 +180,10 @@ class Idefics2ModelTest(ModelTesterMixin, unittest.TestCase):
def test_inputs_embeds(self):
    # Intentionally empty: this override only exists to disable the common test.
    # It takes `self` like every other test method so that it can be invoked as a
    # bound method without raising a TypeError.
    pass
# Skip stub: the equivalence test drops input_ids and passes only embeddings, which the
# skip reason says this model cannot handle; the body is a deliberate no-op.
@unittest.skip("input_embeds cannot be passed in without input_ids")
def test_inputs_embeds_matches_input_ids(self):
pass
@unittest.skip("Model does not support padding right")
def test_flash_attn_2_generate_padding_right(self):
pass

View File

@@ -466,6 +466,31 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
with torch.no_grad():
model(**inputs)[0]
# override because ImageGPT main input name is `pixel_values`
# NOTE: in latest transformers this is deprecated, `input_ids` should be used. TODO
def test_inputs_embeds_matches_input_ids(self):
    """Check that passing embeddings gives the same first output as passing `pixel_values`.

    For every model class, the model is run once with the prepared inputs, then
    again with `pixel_values` replaced by `inputs_embeds`, obtained by applying the
    model's input embedding layer to those values. The first element of both
    outputs must be close according to `torch.allclose`.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        model = model_class(config)
        model.to(torch_device)
        model.eval()

        # Work on a deep copy so that swapping keys below cannot affect `inputs_dict`.
        kwargs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
        with torch.no_grad():
            ids_output = model(**kwargs)[0]

        # Replace the main input by its embeddings, computed with the model's own layer.
        main_input = kwargs.pop("pixel_values")
        embedding_layer = model.get_input_embeddings()
        kwargs["inputs_embeds"] = embedding_layer(main_input)
        with torch.no_grad():
            embeds_output = model(**kwargs)[0]

        self.assertTrue(torch.allclose(embeds_output, ids_output))
def _create_and_check_torchscript(self, config, inputs_dict):
if not self.test_torchscript:
return

View File

@@ -265,6 +265,10 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
lm_heads = model.get_output_embeddings()
self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear))
# Skip stub: opts the MusicGen decoder tests out of the inputs_embeds-vs-input_ids
# equivalence test (presumably inherited from ModelTesterMixin).
@unittest.skip(reason="MusicGen does not use inputs_embeds")
def test_inputs_embeds_matches_input_ids(self):
pass
# skip as this model doesn't support all arguments tested
def test_model_outputs_equivalence(self):
pass

View File

@@ -268,6 +268,10 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
lm_heads = model.get_output_embeddings()
self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear))
# Skip stub: opts the MusicGen Melody decoder tests out of the inputs_embeds-vs-input_ids
# equivalence test (presumably inherited from ModelTesterMixin).
@unittest.skip(reason="MusicGen melody does not use inputs_embeds")
def test_inputs_embeds_matches_input_ids(self):
pass
@unittest.skip("this model doesn't support all arguments tested")
def test_model_outputs_equivalence(self):
pass

View File

@@ -463,6 +463,10 @@ class SeamlessM4TModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase):
def test_inputs_embeds(self):
pass
# Skip stub: the equivalence test calls get_input_embeddings(), which per the skip reason
# has nothing to return for the speech-input variant; the body is a deliberate no-op.
@unittest.skip(reason="SeamlessM4TSpeechEncoder doesn't have an embedding layer")
def test_inputs_embeds_matches_input_ids(self):
pass
@unittest.skip(
reason="Expected missing keys serve when using SeamlessM4TForXXX.from_pretrained from a checkpoint saved by SeamlessM4TModel.save_pretrained."
)

View File

@@ -479,6 +479,10 @@ class SeamlessM4Tv2ModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase)
def test_inputs_embeds(self):
pass
# Skip stub: disables the inputs_embeds-vs-input_ids equivalence test for the v2
# speech-input tests. The reason string uses the v2 class prefix so the skip report
# points at the model family actually under test.
@unittest.skip(reason="SeamlessM4Tv2SpeechEncoder doesn't have an embedding layer")
def test_inputs_embeds_matches_input_ids(self):
    pass
@unittest.skip(
reason="Expected missing keys serve when using SeamlessM4Tv2ForXXX.from_pretrained from a checkpoint saved by SeamlessM4Tv2Model.save_pretrained."
)

View File

@@ -261,6 +261,10 @@ class TableTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, Pipelin
def test_inputs_embeds(self):
pass
# Skip stub: opts Table Transformer out of the inputs_embeds-vs-input_ids equivalence test
# (presumably inherited from ModelTesterMixin); the body is a deliberate no-op.
@unittest.skip(reason="Table Transformer does not use inputs_embeds")
def test_inputs_embeds_matches_input_ids(self):
pass
@unittest.skip(reason="Table Transformer does not have a get_input_embeddings method")
def test_model_common_attributes(self):
pass

View File

@@ -357,6 +357,13 @@ class ViltModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_model_outputs_equivalence(self):
pass
# Skip stub: two forward passes cannot be compared for this model because, per the skip
# reason below, its hidden states are not deterministic; the body is a deliberate no-op.
@unittest.skip(
reason="""VilT samples image tokens from a multinomial distribution, resulting in not deterministic
hidden states. Cannot test equivalence on logit level"""
)
def test_inputs_embeds_matches_input_ids(self):
pass
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True