Add inputs_embeds param to ModernBertModel (#35373)
* update modular_modernbert -- add inputs_embeds param to ModernBertModel * Fix implementation issues; extend to other classes; docstring First of all, the inputs_embeds shouldn't fully replace `self.embeddings(input_ids)`, because this call also does layer normalization and dropout. So, now both input_ids and inputs_embeds are passed to the ModernBertEmbeddings, much like how BertEmbeddings is implemented. I also added `inputs_embeds` to the docstring, and propagated the changes to the other model classes. I also introduced an error if input_ids and inputs_embeds are both or neither provided. Lastly, I fixed an issue where the device for the default attention_mask was based solely on input_ids. * Propagate inputs_embeds to ModernBertForMaskedLM correctly Also reintroduce inputs_embeds test --------- Co-authored-by: Tom Aarsen <Cubiegamedev@gmail.com>
This commit is contained in:
@@ -205,12 +205,17 @@ class ModernBertEmbeddings(nn.Module):
|
||||
def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
|
||||
return self.drop(self.norm(self.tok_embeddings(input_ids)))
|
||||
|
||||
def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
|
||||
hidden_states = (
|
||||
self.compiled_embeddings(input_ids)
|
||||
if self.config.reference_compile
|
||||
else self.drop(self.norm(self.tok_embeddings(input_ids)))
|
||||
)
|
||||
def forward(
|
||||
self, input_ids: torch.LongTensor = None, inputs_embeds: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = self.drop(self.norm(inputs_embeds))
|
||||
else:
|
||||
hidden_states = (
|
||||
self.compiled_embeddings(input_ids)
|
||||
if self.config.reference_compile
|
||||
else self.drop(self.norm(self.tok_embeddings(input_ids)))
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -791,6 +796,10 @@ MODERNBERT_INPUTS_DOCSTRING = r"""
|
||||
config.n_positions - 1]`.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
||||
model's internal embedding lookup matrix.
|
||||
indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
|
||||
Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
|
||||
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
|
||||
@@ -842,10 +851,11 @@ class ModernBertModel(ModernBertPreTrainedModel):
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
sliding_window_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[int] = None,
|
||||
@@ -861,35 +871,49 @@ class ModernBertModel(ModernBertPreTrainedModel):
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
|
||||
self._maybe_set_compile()
|
||||
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
||||
|
||||
if input_ids is not None:
|
||||
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
||||
|
||||
if batch_size is None and seq_len is None:
|
||||
batch_size, seq_len = input_ids.shape[:2]
|
||||
if inputs_embeds is not None:
|
||||
batch_size, seq_len = inputs_embeds.shape[:2]
|
||||
else:
|
||||
batch_size, seq_len = input_ids.shape[:2]
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
|
||||
attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
|
||||
|
||||
repad = False
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if indices is None and cu_seqlens is None and max_seqlen is None:
|
||||
repad = True
|
||||
with torch.no_grad():
|
||||
input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
|
||||
inputs=input_ids, attention_mask=attention_mask
|
||||
if inputs_embeds is None:
|
||||
with torch.no_grad():
|
||||
input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
|
||||
inputs=input_ids, attention_mask=attention_mask
|
||||
)
|
||||
else:
|
||||
inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
|
||||
inputs=inputs_embeds, attention_mask=attention_mask
|
||||
)
|
||||
else:
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
|
||||
position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
|
||||
|
||||
attention_mask, sliding_window_mask = self._update_attention_mask(
|
||||
attention_mask, output_attentions=output_attentions
|
||||
)
|
||||
|
||||
hidden_states = self.embeddings(input_ids)
|
||||
hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
|
||||
|
||||
for encoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
@@ -1025,10 +1049,11 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
sliding_window_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
@@ -1045,19 +1070,32 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
|
||||
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if indices is None and cu_seqlens is None and max_seqlen is None:
|
||||
batch_size, seq_len = input_ids.shape[:2]
|
||||
if batch_size is None and seq_len is None:
|
||||
if inputs_embeds is not None:
|
||||
batch_size, seq_len = inputs_embeds.shape[:2]
|
||||
else:
|
||||
batch_size, seq_len = input_ids.shape[:2]
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
|
||||
with torch.no_grad():
|
||||
input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
|
||||
inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
|
||||
attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
|
||||
|
||||
if inputs_embeds is None:
|
||||
with torch.no_grad():
|
||||
input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
|
||||
inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
|
||||
)
|
||||
else:
|
||||
inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
|
||||
inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
|
||||
)
|
||||
|
||||
outputs = self.model(
|
||||
input_ids,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
sliding_window_mask=sliding_window_mask,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
indices=indices,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
@@ -1130,10 +1168,11 @@ class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
sliding_window_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
@@ -1155,10 +1194,11 @@ class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
|
||||
self._maybe_set_compile()
|
||||
|
||||
outputs = self.model(
|
||||
input_ids,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
sliding_window_mask=sliding_window_mask,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
indices=indices,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
@@ -1241,10 +1281,11 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
sliding_window_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
@@ -1263,10 +1304,11 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
|
||||
self._maybe_set_compile()
|
||||
|
||||
outputs = self.model(
|
||||
input_ids,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
sliding_window_mask=sliding_window_mask,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
indices=indices,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
|
||||
@@ -464,12 +464,17 @@ class ModernBertEmbeddings(nn.Module):
|
||||
def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
|
||||
return self.drop(self.norm(self.tok_embeddings(input_ids)))
|
||||
|
||||
def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
|
||||
hidden_states = (
|
||||
self.compiled_embeddings(input_ids)
|
||||
if self.config.reference_compile
|
||||
else self.drop(self.norm(self.tok_embeddings(input_ids)))
|
||||
)
|
||||
def forward(
|
||||
self, input_ids: torch.LongTensor = None, inputs_embeds: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = self.drop(self.norm(inputs_embeds))
|
||||
else:
|
||||
hidden_states = (
|
||||
self.compiled_embeddings(input_ids)
|
||||
if self.config.reference_compile
|
||||
else self.drop(self.norm(self.tok_embeddings(input_ids)))
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -944,6 +949,10 @@ MODERNBERT_INPUTS_DOCSTRING = r"""
|
||||
config.n_positions - 1]`.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
||||
model's internal embedding lookup matrix.
|
||||
indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
|
||||
Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
|
||||
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
|
||||
@@ -995,10 +1004,11 @@ class ModernBertModel(ModernBertPreTrainedModel):
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
sliding_window_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[int] = None,
|
||||
@@ -1014,35 +1024,49 @@ class ModernBertModel(ModernBertPreTrainedModel):
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
|
||||
self._maybe_set_compile()
|
||||
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
||||
|
||||
if input_ids is not None:
|
||||
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
||||
|
||||
if batch_size is None and seq_len is None:
|
||||
batch_size, seq_len = input_ids.shape[:2]
|
||||
if inputs_embeds is not None:
|
||||
batch_size, seq_len = inputs_embeds.shape[:2]
|
||||
else:
|
||||
batch_size, seq_len = input_ids.shape[:2]
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
|
||||
attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
|
||||
|
||||
repad = False
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if indices is None and cu_seqlens is None and max_seqlen is None:
|
||||
repad = True
|
||||
with torch.no_grad():
|
||||
input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
|
||||
inputs=input_ids, attention_mask=attention_mask
|
||||
if inputs_embeds is None:
|
||||
with torch.no_grad():
|
||||
input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
|
||||
inputs=input_ids, attention_mask=attention_mask
|
||||
)
|
||||
else:
|
||||
inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
|
||||
inputs=inputs_embeds, attention_mask=attention_mask
|
||||
)
|
||||
else:
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
|
||||
position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
|
||||
|
||||
attention_mask, sliding_window_mask = self._update_attention_mask(
|
||||
attention_mask, output_attentions=output_attentions
|
||||
)
|
||||
|
||||
hidden_states = self.embeddings(input_ids)
|
||||
hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
|
||||
|
||||
for encoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
@@ -1178,10 +1202,11 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
sliding_window_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
@@ -1198,19 +1223,32 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
|
||||
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if indices is None and cu_seqlens is None and max_seqlen is None:
|
||||
batch_size, seq_len = input_ids.shape[:2]
|
||||
if batch_size is None and seq_len is None:
|
||||
if inputs_embeds is not None:
|
||||
batch_size, seq_len = inputs_embeds.shape[:2]
|
||||
else:
|
||||
batch_size, seq_len = input_ids.shape[:2]
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
|
||||
with torch.no_grad():
|
||||
input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
|
||||
inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
|
||||
attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
|
||||
|
||||
if inputs_embeds is None:
|
||||
with torch.no_grad():
|
||||
input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
|
||||
inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
|
||||
)
|
||||
else:
|
||||
inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
|
||||
inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
|
||||
)
|
||||
|
||||
outputs = self.model(
|
||||
input_ids,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
sliding_window_mask=sliding_window_mask,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
indices=indices,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
@@ -1283,10 +1321,11 @@ class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
sliding_window_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
@@ -1308,10 +1347,11 @@ class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
|
||||
self._maybe_set_compile()
|
||||
|
||||
outputs = self.model(
|
||||
input_ids,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
sliding_window_mask=sliding_window_mask,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
indices=indices,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
@@ -1394,10 +1434,11 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
sliding_window_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
@@ -1416,10 +1457,11 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
|
||||
self._maybe_set_compile()
|
||||
|
||||
outputs = self.model(
|
||||
input_ids,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
sliding_window_mask=sliding_window_mask,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
indices=indices,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
|
||||
@@ -146,7 +146,11 @@ class ModernBertModelTester:
|
||||
# If we're testing `test_retain_grad_hidden_states_attentions`, we normally get an error
|
||||
# that compilation doesn't work. Users can then set compile=False when loading the model,
|
||||
# much like here. We're testing whether it works once they've done that.
|
||||
if test_name == "test_retain_grad_hidden_states_attentions":
|
||||
|
||||
# If we're testing `test_inputs_embeds_matches_input_ids`, then we'd like to test with `reference_compile`
|
||||
# set to False, otherwise the input_ids with compiled input embeddings will not match the inputs_embeds
|
||||
# with atol=1e-8 and rtol=1e-5
|
||||
if test_name in ("test_retain_grad_hidden_states_attentions", "test_inputs_embeds_matches_input_ids"):
|
||||
config.reference_compile = False
|
||||
# Some tests require attentions to be outputted, in that case we'll set the attention implementation to eager
|
||||
# as the others don't support outputted attentions
|
||||
@@ -294,10 +298,6 @@ class ModernBertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
@unittest.skip("ModernBert doesn't use `inputs_embeds` as input.")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
def test_for_masked_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
|
||||
|
||||
Reference in New Issue
Block a user