Encoder-decoder models: move embedding scale to nn.Module (#30410)

* move scaling to nn.Module

* let the test be here for now (need to fix)

* failing tests

* last failing models

* Revert commit 4c14817f38

* clean-up

* oops forgot

* codestyle

* raise NotImplemented when possible

* Update tests/test_modeling_common.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* skip tests in respective modeling files

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Raushan Turganbay
2024-05-01 12:33:00 +05:00
committed by GitHub
parent 9d31b32e9d
commit 38a4bf79ad
36 changed files with 541 additions and 80 deletions

View File

@@ -579,6 +579,29 @@ class BarkSemanticModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Te
with torch.no_grad():
model(**inputs)[0]
# override as the input arg is called "input_embeds", not "inputs_embeds"
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
with torch.no_grad():
out_ids = model(**inputs)[0]
input_ids = inputs["input_ids"]
del inputs["input_ids"]
wte = model.get_input_embeddings()
inputs["input_embeds"] = wte(input_ids)
with torch.no_grad():
out_embeds = model(**inputs)[0]
self.assertTrue(torch.allclose(out_embeds, out_ids))
@require_torch_fp16
def test_generate_fp16(self):
config, input_dict = self.model_tester.prepare_config_and_inputs()
@@ -645,6 +668,29 @@ class BarkCoarseModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
with torch.no_grad():
model(**inputs)[0]
# override as the input arg is called "input_embeds", not "inputs_embeds"
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
with torch.no_grad():
out_ids = model(**inputs)[0]
input_ids = inputs["input_ids"]
del inputs["input_ids"]
wte = model.get_input_embeddings()
inputs["input_embeds"] = wte(input_ids)
with torch.no_grad():
out_embeds = model(**inputs)[0]
self.assertTrue(torch.allclose(out_embeds, out_ids))
@require_torch_fp16
def test_generate_fp16(self):
config, input_dict = self.model_tester.prepare_config_and_inputs()
@@ -709,6 +755,10 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
with torch.no_grad():
model(**inputs)[0]
@unittest.skip("FineModel relies on codebook idx and does not return same logits")
def test_inputs_embeds_matches_input_ids(self):
pass
@require_torch_fp16
def test_generate_fp16(self):
config, input_dict = self.model_tester.prepare_config_and_inputs()