Add inputs_embeds param to ModernBertModel (#35373)

* update modular_modernbert -- add inputs_embeds param to ModernBertModel

* Fix implementation issues; extend to other classes; docstring

First of all, the inputs_embeds shouldn't fully replace `self.embeddings(input_ids)`, because this call also does layer normalization and dropout. So, now both input_ids and inputs_embeds is passed to the ModernBertEmbeddings, much like how BertEmbeddings is implemented.

I also added `inputs_embeds` to the docstring, and propagated the changes to the other model classes.

I also introduced an error if input_ids and input_embeds are both or neither provided.

Lastly, I fixed an issue with device being based solely on input_ids with attention_mask.

* Propagate inputs_embeds to ModernBertForMaskedLM correctly

Also reintroduce inputs_embeds test

---------

Co-authored-by: Tom Aarsen <Cubiegamedev@gmail.com>
This commit is contained in:
Jack Morris
2025-01-09 08:17:26 -05:00
committed by GitHub
parent 1b2f942af7
commit 832c6191ed
3 changed files with 141 additions and 57 deletions

View File

@@ -146,7 +146,11 @@ class ModernBertModelTester:
# If we're testing `test_retain_grad_hidden_states_attentions`, we normally get an error
# that compilation doesn't work. Users can then set compile=False when loading the model,
# much like here. We're testing whether it works once they've done that.
if test_name == "test_retain_grad_hidden_states_attentions":
# If we're testing `test_inputs_embeds_matches_input_ids`, then we'd like to test with `reference_compile`
# set to False, otherwise the input_ids with compiled input embeddings will not match the inputs_embeds
# with atol=1e-8 and rtol=1e-5
if test_name in ("test_retain_grad_hidden_states_attentions", "test_inputs_embeds_matches_input_ids"):
config.reference_compile = False
# Some tests require attentions to be outputted, in that case we'll set the attention implementation to eager
# as the others don't support outputted attentions
@@ -294,10 +298,6 @@ class ModernBertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
@unittest.skip("ModernBert doesn't use `inputs_embeds` as input.")
def test_inputs_embeds(self):
pass
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)