Encoder-decoder models: move embedding scale to nn.Module (#30410)
* move scaling to nn.Module * let the test be here for now (need to fix) * failing tests * last failing models * Revert commit 4c14817f38 * clean-up * oops forgot * codestyle * raise NotImplemented when possible * Update tests/test_modeling_common.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * skip tests in respective modeling files --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
9d31b32e9d
commit
38a4bf79ad
@@ -579,6 +579,29 @@ class BarkSemanticModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Te
|
||||
with torch.no_grad():
|
||||
model(**inputs)[0]
|
||||
|
||||
# override as the input arg is called "input_embeds", not "inputs_embeds"
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
with torch.no_grad():
|
||||
out_ids = model(**inputs)[0]
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["input_embeds"] = wte(input_ids)
|
||||
with torch.no_grad():
|
||||
out_embeds = model(**inputs)[0]
|
||||
|
||||
self.assertTrue(torch.allclose(out_embeds, out_ids))
|
||||
|
||||
@require_torch_fp16
|
||||
def test_generate_fp16(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs()
|
||||
@@ -645,6 +668,29 @@ class BarkCoarseModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
|
||||
with torch.no_grad():
|
||||
model(**inputs)[0]
|
||||
|
||||
# override as the input arg is called "input_embeds", not "inputs_embeds"
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
with torch.no_grad():
|
||||
out_ids = model(**inputs)[0]
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["input_embeds"] = wte(input_ids)
|
||||
with torch.no_grad():
|
||||
out_embeds = model(**inputs)[0]
|
||||
|
||||
self.assertTrue(torch.allclose(out_embeds, out_ids))
|
||||
|
||||
@require_torch_fp16
|
||||
def test_generate_fp16(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs()
|
||||
@@ -709,6 +755,10 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
model(**inputs)[0]
|
||||
|
||||
@unittest.skip("FineModel relies on codebook idx and does not return same logits")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@require_torch_fp16
|
||||
def test_generate_fp16(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs()
|
||||
|
||||
Reference in New Issue
Block a user