Encoder-decoder models: move embedding scale to nn.Module (#30410)
* move scaling to nn.Module * let the test be here for now (need to fix) * failing tests * last failing models * Revert commit 4c14817f38 * clean-up * oops forgot * codestyle * raise NotImplemented when possible * Update tests/test_modeling_common.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * skip tests in respective modeling files --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
9d31b32e9d
commit
38a4bf79ad
@@ -167,6 +167,10 @@ class AlignVisionModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="AlignVisionModel does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="AlignVisionModel does not support input and output embeddings")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
@@ -379,6 +383,10 @@ class AlignTextModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Align does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="AlignTextModel has no base class and is not available in MODEL_MAPPING")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
@@ -473,6 +481,10 @@ class AlignModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Align does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Retain_grad is tested in individual model tests")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
pass
|
||||
|
||||
@@ -579,6 +579,29 @@ class BarkSemanticModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Te
|
||||
with torch.no_grad():
|
||||
model(**inputs)[0]
|
||||
|
||||
# override as the input arg is called "input_embeds", not "inputs_embeds"
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
with torch.no_grad():
|
||||
out_ids = model(**inputs)[0]
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["input_embeds"] = wte(input_ids)
|
||||
with torch.no_grad():
|
||||
out_embeds = model(**inputs)[0]
|
||||
|
||||
self.assertTrue(torch.allclose(out_embeds, out_ids))
|
||||
|
||||
@require_torch_fp16
|
||||
def test_generate_fp16(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs()
|
||||
@@ -645,6 +668,29 @@ class BarkCoarseModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
|
||||
with torch.no_grad():
|
||||
model(**inputs)[0]
|
||||
|
||||
# override as the input arg is called "input_embeds", not "inputs_embeds"
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
with torch.no_grad():
|
||||
out_ids = model(**inputs)[0]
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["input_embeds"] = wte(input_ids)
|
||||
with torch.no_grad():
|
||||
out_embeds = model(**inputs)[0]
|
||||
|
||||
self.assertTrue(torch.allclose(out_embeds, out_ids))
|
||||
|
||||
@require_torch_fp16
|
||||
def test_generate_fp16(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs()
|
||||
@@ -709,6 +755,10 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
model(**inputs)[0]
|
||||
|
||||
@unittest.skip("FineModel relies on codebook idx and does not return same logits")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@require_torch_fp16
|
||||
def test_generate_fp16(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs()
|
||||
|
||||
@@ -506,6 +506,10 @@ class BridgeTowerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Bridge Tower does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
|
||||
@@ -502,6 +502,10 @@ class CanineModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
# ViT does not use inputs_embeds
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Canine Tower does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("CANINE does not have a get_input_embeddings() method.")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
@@ -247,6 +247,10 @@ class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Conditional DETR does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Conditional DETR does not have a get_input_embeddings method")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
@@ -253,6 +253,10 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Deformable DETR does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Deformable DETR does not have a get_input_embeddings method")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
@@ -303,6 +303,10 @@ class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="DETA does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="DETA does not have a get_input_embeddings method")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
@@ -247,6 +247,10 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="DETR does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="DETR does not have a get_input_embeddings method")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
@@ -321,6 +321,10 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Input ids is required for FSMT.")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("model weights aren't tied in FSMT.")
|
||||
def test_tie_model_weights(self):
|
||||
pass
|
||||
|
||||
@@ -182,6 +182,14 @@ class GPTSanJapaneseTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
|
||||
def test_model_parallelism(self):
|
||||
super().test_model_parallelism()
|
||||
|
||||
@unittest.skip(reason="Gptsan does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Gptsan does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class GPTSanJapaneseForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
@@ -212,6 +220,14 @@ class GPTSanJapaneseForConditionalGenerationTest(ModelTesterMixin, GenerationTes
|
||||
def test_model_parallelism(self):
|
||||
super().test_model_parallelism()
|
||||
|
||||
@unittest.skip(reason="Gptsan does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Gptsan does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_logits(self):
|
||||
model = GPTSanJapaneseForConditionalGeneration.from_pretrained("Tanrei/GPTSAN-japanese")
|
||||
|
||||
@@ -382,6 +382,10 @@ class IBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
model(**inputs)[0]
|
||||
|
||||
@unittest.skip("ibert overrides scaling to None if inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class IBertModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
@@ -180,6 +180,10 @@ class Idefics2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_inputs_embeds():
|
||||
pass
|
||||
|
||||
@unittest.skip("input_embeds cannot be passed in without input_ids")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Model does not support padding right")
|
||||
def test_flash_attn_2_generate_padding_right(self):
|
||||
pass
|
||||
|
||||
@@ -466,6 +466,31 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
||||
with torch.no_grad():
|
||||
model(**inputs)[0]
|
||||
|
||||
# override because ImageGPT main input name is `pixel_values`
|
||||
# NOTE: in latest transformers this is deprecated, `input_ids` should be used. TODO
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
with torch.no_grad():
|
||||
out_ids = model(**inputs)[0]
|
||||
|
||||
pixel_values = inputs["pixel_values"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(pixel_values)
|
||||
|
||||
with torch.no_grad():
|
||||
out_embeds = model(**inputs)[0]
|
||||
|
||||
self.assertTrue(torch.allclose(out_embeds, out_ids))
|
||||
|
||||
def _create_and_check_torchscript(self, config, inputs_dict):
|
||||
if not self.test_torchscript:
|
||||
return
|
||||
|
||||
@@ -265,6 +265,10 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
lm_heads = model.get_output_embeddings()
|
||||
self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear))
|
||||
|
||||
@unittest.skip(reason="MusicGen does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
# skip as this model doesn't support all arguments tested
|
||||
def test_model_outputs_equivalence(self):
|
||||
pass
|
||||
|
||||
@@ -268,6 +268,10 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
|
||||
lm_heads = model.get_output_embeddings()
|
||||
self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear))
|
||||
|
||||
@unittest.skip(reason="MusicGen melody does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("this model doesn't support all arguments tested")
|
||||
def test_model_outputs_equivalence(self):
|
||||
pass
|
||||
|
||||
@@ -463,6 +463,10 @@ class SeamlessM4TModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="SeamlessM4TSpeechEncoder doesn't have an embedding layer")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="Expected missing keys serve when using SeamlessM4TForXXX.from_pretrained from a checkpoint saved by SeamlessM4TModel.save_pretrained."
|
||||
)
|
||||
|
||||
@@ -479,6 +479,10 @@ class SeamlessM4Tv2ModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase)
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="SeamlessM4TSpeechEncoder doesn't have an embedding layer")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="Expected missing keys serve when using SeamlessM4Tv2ForXXX.from_pretrained from a checkpoint saved by SeamlessM4Tv2Model.save_pretrained."
|
||||
)
|
||||
|
||||
@@ -261,6 +261,10 @@ class TableTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, Pipelin
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Table Transformer does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Table Transformer does not have a get_input_embeddings method")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
@@ -357,6 +357,13 @@ class ViltModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
def test_model_outputs_equivalence(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="""VilT samples image tokens from a multinomial distribution, resulting in not deterministic
|
||||
hidden states. Cannot test equivalence on logit level"""
|
||||
)
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
Reference in New Issue
Block a user