Add inputs_embeds param to ModernBertModel (#35373)

* update modular_modernbert -- add inputs_embeds param to ModernBertModel

* Fix implementation issues; extend to other classes; docstring

First of all, the inputs_embeds shouldn't fully replace `self.embeddings(input_ids)`, because this call also does layer normalization and dropout. So, now both input_ids and inputs_embeds is passed to the ModernBertEmbeddings, much like how BertEmbeddings is implemented.

I also added `inputs_embeds` to the docstring, and propagated the changes to the other model classes.

I also introduced an error if input_ids and input_embeds are both or neither provided.

Lastly, I fixed an issue with device being based solely on input_ids with attention_mask.

* Propagate inputs_embeds to ModernBertForMaskedLM correctly

Also reintroduce inputs_embeds test

---------

Co-authored-by: Tom Aarsen <Cubiegamedev@gmail.com>
This commit is contained in:
Jack Morris
2025-01-09 08:17:26 -05:00
committed by GitHub
parent 1b2f942af7
commit 832c6191ed
3 changed files with 141 additions and 57 deletions

View File

@@ -205,12 +205,17 @@ class ModernBertEmbeddings(nn.Module):
def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
return self.drop(self.norm(self.tok_embeddings(input_ids)))
def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
hidden_states = (
self.compiled_embeddings(input_ids)
if self.config.reference_compile
else self.drop(self.norm(self.tok_embeddings(input_ids)))
)
def forward(
self, input_ids: torch.LongTensor = None, inputs_embeds: Optional[torch.Tensor] = None
) -> torch.Tensor:
if inputs_embeds is not None:
hidden_states = self.drop(self.norm(inputs_embeds))
else:
hidden_states = (
self.compiled_embeddings(input_ids)
if self.config.reference_compile
else self.drop(self.norm(self.tok_embeddings(input_ids)))
)
return hidden_states
@@ -791,6 +796,10 @@ MODERNBERT_INPUTS_DOCSTRING = r"""
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
@@ -842,10 +851,11 @@ class ModernBertModel(ModernBertPreTrainedModel):
)
def forward(
self,
input_ids: torch.LongTensor = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
sliding_window_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
indices: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
@@ -861,35 +871,49 @@ class ModernBertModel(ModernBertPreTrainedModel):
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
self._maybe_set_compile()
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
if input_ids is not None:
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
if batch_size is None and seq_len is None:
batch_size, seq_len = input_ids.shape[:2]
if inputs_embeds is not None:
batch_size, seq_len = inputs_embeds.shape[:2]
else:
batch_size, seq_len = input_ids.shape[:2]
device = input_ids.device if input_ids is not None else inputs_embeds.device
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
repad = False
if self.config._attn_implementation == "flash_attention_2":
if indices is None and cu_seqlens is None and max_seqlen is None:
repad = True
with torch.no_grad():
input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
inputs=input_ids, attention_mask=attention_mask
if inputs_embeds is None:
with torch.no_grad():
input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
inputs=input_ids, attention_mask=attention_mask
)
else:
inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
inputs=inputs_embeds, attention_mask=attention_mask
)
else:
if position_ids is None:
position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
attention_mask, sliding_window_mask = self._update_attention_mask(
attention_mask, output_attentions=output_attentions
)
hidden_states = self.embeddings(input_ids)
hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
for encoder_layer in self.layers:
if output_hidden_states:
@@ -1025,10 +1049,11 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
)
def forward(
self,
input_ids: Optional[torch.Tensor],
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
sliding_window_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
indices: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
@@ -1045,19 +1070,32 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
if self.config._attn_implementation == "flash_attention_2":
if indices is None and cu_seqlens is None and max_seqlen is None:
batch_size, seq_len = input_ids.shape[:2]
if batch_size is None and seq_len is None:
if inputs_embeds is not None:
batch_size, seq_len = inputs_embeds.shape[:2]
else:
batch_size, seq_len = input_ids.shape[:2]
device = input_ids.device if input_ids is not None else inputs_embeds.device
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
with torch.no_grad():
input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
if inputs_embeds is None:
with torch.no_grad():
input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
)
else:
inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
)
outputs = self.model(
input_ids,
input_ids=input_ids,
attention_mask=attention_mask,
sliding_window_mask=sliding_window_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
@@ -1130,10 +1168,11 @@ class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
)
def forward(
self,
input_ids: Optional[torch.Tensor],
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
sliding_window_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
indices: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
@@ -1155,10 +1194,11 @@ class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
self._maybe_set_compile()
outputs = self.model(
input_ids,
input_ids=input_ids,
attention_mask=attention_mask,
sliding_window_mask=sliding_window_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
@@ -1241,10 +1281,11 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
)
def forward(
self,
input_ids: Optional[torch.Tensor],
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
sliding_window_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
indices: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
@@ -1263,10 +1304,11 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
self._maybe_set_compile()
outputs = self.model(
input_ids,
input_ids=input_ids,
attention_mask=attention_mask,
sliding_window_mask=sliding_window_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,

View File

@@ -464,12 +464,17 @@ class ModernBertEmbeddings(nn.Module):
def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
return self.drop(self.norm(self.tok_embeddings(input_ids)))
def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
hidden_states = (
self.compiled_embeddings(input_ids)
if self.config.reference_compile
else self.drop(self.norm(self.tok_embeddings(input_ids)))
)
def forward(
self, input_ids: torch.LongTensor = None, inputs_embeds: Optional[torch.Tensor] = None
) -> torch.Tensor:
if inputs_embeds is not None:
hidden_states = self.drop(self.norm(inputs_embeds))
else:
hidden_states = (
self.compiled_embeddings(input_ids)
if self.config.reference_compile
else self.drop(self.norm(self.tok_embeddings(input_ids)))
)
return hidden_states
@@ -944,6 +949,10 @@ MODERNBERT_INPUTS_DOCSTRING = r"""
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
@@ -995,10 +1004,11 @@ class ModernBertModel(ModernBertPreTrainedModel):
)
def forward(
self,
input_ids: torch.LongTensor = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
sliding_window_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
indices: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
@@ -1014,35 +1024,49 @@ class ModernBertModel(ModernBertPreTrainedModel):
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
self._maybe_set_compile()
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
if input_ids is not None:
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
if batch_size is None and seq_len is None:
batch_size, seq_len = input_ids.shape[:2]
if inputs_embeds is not None:
batch_size, seq_len = inputs_embeds.shape[:2]
else:
batch_size, seq_len = input_ids.shape[:2]
device = input_ids.device if input_ids is not None else inputs_embeds.device
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
repad = False
if self.config._attn_implementation == "flash_attention_2":
if indices is None and cu_seqlens is None and max_seqlen is None:
repad = True
with torch.no_grad():
input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
inputs=input_ids, attention_mask=attention_mask
if inputs_embeds is None:
with torch.no_grad():
input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
inputs=input_ids, attention_mask=attention_mask
)
else:
inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
inputs=inputs_embeds, attention_mask=attention_mask
)
else:
if position_ids is None:
position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
attention_mask, sliding_window_mask = self._update_attention_mask(
attention_mask, output_attentions=output_attentions
)
hidden_states = self.embeddings(input_ids)
hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
for encoder_layer in self.layers:
if output_hidden_states:
@@ -1178,10 +1202,11 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
)
def forward(
self,
input_ids: Optional[torch.Tensor],
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
sliding_window_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
indices: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
@@ -1198,19 +1223,32 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
if self.config._attn_implementation == "flash_attention_2":
if indices is None and cu_seqlens is None and max_seqlen is None:
batch_size, seq_len = input_ids.shape[:2]
if batch_size is None and seq_len is None:
if inputs_embeds is not None:
batch_size, seq_len = inputs_embeds.shape[:2]
else:
batch_size, seq_len = input_ids.shape[:2]
device = input_ids.device if input_ids is not None else inputs_embeds.device
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
with torch.no_grad():
input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
if inputs_embeds is None:
with torch.no_grad():
input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
)
else:
inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
)
outputs = self.model(
input_ids,
input_ids=input_ids,
attention_mask=attention_mask,
sliding_window_mask=sliding_window_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
@@ -1283,10 +1321,11 @@ class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
)
def forward(
self,
input_ids: Optional[torch.Tensor],
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
sliding_window_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
indices: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
@@ -1308,10 +1347,11 @@ class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
self._maybe_set_compile()
outputs = self.model(
input_ids,
input_ids=input_ids,
attention_mask=attention_mask,
sliding_window_mask=sliding_window_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
@@ -1394,10 +1434,11 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
)
def forward(
self,
input_ids: Optional[torch.Tensor],
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
sliding_window_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
indices: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
@@ -1416,10 +1457,11 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
self._maybe_set_compile()
outputs = self.model(
input_ids,
input_ids=input_ids,
attention_mask=attention_mask,
sliding_window_mask=sliding_window_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,

View File

@@ -146,7 +146,11 @@ class ModernBertModelTester:
# If we're testing `test_retain_grad_hidden_states_attentions`, we normally get an error
# that compilation doesn't work. Users can then set compile=False when loading the model,
# much like here. We're testing whether it works once they've done that.
if test_name == "test_retain_grad_hidden_states_attentions":
# If we're testing `test_inputs_embeds_matches_input_ids`, then we'd like to test with `reference_compile`
# set to False, otherwise the input_ids with compiled input embeddings will not match the inputs_embeds
# with atol=1e-8 and rtol=1e-5
if test_name in ("test_retain_grad_hidden_states_attentions", "test_inputs_embeds_matches_input_ids"):
config.reference_compile = False
# Some tests require attentions to be outputted, in that case we'll set the attention implementation to eager
# as the others don't support outputted attentions
@@ -294,10 +298,6 @@ class ModernBertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
@unittest.skip("ModernBert doesn't use `inputs_embeds` as input.")
def test_inputs_embeds(self):
pass
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)