Add inputs_embeds param to ModernBertModel (#35373)
* update modular_modernbert -- add inputs_embeds param to ModernBertModel * Fix implementation issues; extend to other classes; docstring First of all, the inputs_embeds shouldn't fully replace `self.embeddings(input_ids)`, because this call also does layer normalization and dropout. So, now both input_ids and inputs_embeds is passed to the ModernBertEmbeddings, much like how BertEmbeddings is implemented. I also added `inputs_embeds` to the docstring, and propagated the changes to the other model classes. I also introduced an error if input_ids and input_embeds are both or neither provided. Lastly, I fixed an issue with device being based solely on input_ids with attention_mask. * Propagate inputs_embeds to ModernBertForMaskedLM correctly Also reintroduce inputs_embeds test --------- Co-authored-by: Tom Aarsen <Cubiegamedev@gmail.com>
This commit is contained in:
@@ -146,7 +146,11 @@ class ModernBertModelTester:
|
||||
# If we're testing `test_retain_grad_hidden_states_attentions`, we normally get an error
|
||||
# that compilation doesn't work. Users can then set compile=False when loading the model,
|
||||
# much like here. We're testing whether it works once they've done that.
|
||||
if test_name == "test_retain_grad_hidden_states_attentions":
|
||||
|
||||
# If we're testing `test_inputs_embeds_matches_input_ids`, then we'd like to test with `reference_compile`
|
||||
# set to False, otherwise the input_ids with compiled input embeddings will not match the inputs_embeds
|
||||
# with atol=1e-8 and rtol=1e-5
|
||||
if test_name in ("test_retain_grad_hidden_states_attentions", "test_inputs_embeds_matches_input_ids"):
|
||||
config.reference_compile = False
|
||||
# Some tests require attentions to be outputted, in that case we'll set the attention implementation to eager
|
||||
# as the others don't support outputted attentions
|
||||
@@ -294,10 +298,6 @@ class ModernBertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
@unittest.skip("ModernBert doesn't use `inputs_embeds` as input.")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
def test_for_masked_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
|
||||
|
||||
Reference in New Issue
Block a user