ModernBERT bug fixes (#35404)
* bug fixes * organize imports * wrap cpu warning in reference_compile * Avoid needing repad_logits_with_grad, always repad with grads when training I'm not 100% sure that the conditional with "or labels is None" makes sense though - not sure what the intention is there. Perhaps we can remove that? * Revert "Avoid needing repad_logits_with_grad, always repad with grads when training" This reverts commit cedcb4e89bcea199a1135a0933e71f534b656239. * Fix grammar: keep -> keeps * Propagate grammar fix with modular_model_converter --------- Co-authored-by: Tom Aarsen <Cubiegamedev@gmail.com> Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
This commit is contained in:
@@ -505,7 +505,7 @@
|
|||||||
- local: model_doc/mobilebert
|
- local: model_doc/mobilebert
|
||||||
title: MobileBERT
|
title: MobileBERT
|
||||||
- local: model_doc/modernbert
|
- local: model_doc/modernbert
|
||||||
title: ModernBert
|
title: ModernBERT
|
||||||
- local: model_doc/mpnet
|
- local: model_doc/mpnet
|
||||||
title: MPNet
|
title: MPNet
|
||||||
- local: model_doc/mpt
|
- local: model_doc/mpt
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ rendered properly in your Markdown viewer.
|
|||||||
|
|
||||||
-->
|
-->
|
||||||
|
|
||||||
# ModernBert
|
# ModernBERT
|
||||||
|
|
||||||
<div class="flex flex-wrap space-x-1">
|
<div class="flex flex-wrap space-x-1">
|
||||||
<a href="https://huggingface.co/models?filter=modernbert">
|
<a href="https://huggingface.co/models?filter=modernbert">
|
||||||
@@ -27,7 +27,7 @@ rendered properly in your Markdown viewer.
|
|||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
The ModernBert model was proposed in [Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference](https://arxiv.org/abs/2412.13663) by Benjamin Warner, Antoine Chaffin, Benjamin Clavié, Orion Weller, Oskar Hallström, Said Taghadouini, Alexis Galalgher, Raja Bisas, Faisal Ladhak, Tom Aarsen, Nathan Cooper, Grifin Adams, Jeremy Howard and Iacopo Poli.
|
The ModernBERT model was proposed in [Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference](https://arxiv.org/abs/2412.13663) by Benjamin Warner, Antoine Chaffin, Benjamin Clavié, Orion Weller, Oskar Hallström, Said Taghadouini, Alexis Gallagher, Raja Biswas, Faisal Ladhak, Tom Aarsen, Nathan Cooper, Griffin Adams, Jeremy Howard and Iacopo Poli.
|
||||||
|
|
||||||
It is a refresh of the traditional encoder architecture, as used in previous models such as [BERT](https://huggingface.co/docs/transformers/en/model_doc/bert) and [RoBERTa](https://huggingface.co/docs/transformers/en/model_doc/roberta).
|
It is a refresh of the traditional encoder architecture, as used in previous models such as [BERT](https://huggingface.co/docs/transformers/en/model_doc/bert) and [RoBERTa](https://huggingface.co/docs/transformers/en/model_doc/roberta).
|
||||||
|
|
||||||
|
|||||||
@@ -109,6 +109,9 @@ class ModernBertConfig(PretrainedConfig):
|
|||||||
the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
|
the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
|
||||||
shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
|
shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
|
||||||
be faster in some scenarios.
|
be faster in some scenarios.
|
||||||
|
repad_logits_with_grad (`bool`, *optional*, defaults to `False`):
|
||||||
|
When True, ModernBertForMaskedLM keeps track of the logits' gradient when repadding for output. This only
|
||||||
|
applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
@@ -164,6 +167,7 @@ class ModernBertConfig(PretrainedConfig):
|
|||||||
sparse_prediction=False,
|
sparse_prediction=False,
|
||||||
sparse_pred_ignore_index=-100,
|
sparse_pred_ignore_index=-100,
|
||||||
reference_compile=None,
|
reference_compile=None,
|
||||||
|
repad_logits_with_grad=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -203,6 +207,7 @@ class ModernBertConfig(PretrainedConfig):
|
|||||||
self.sparse_prediction = sparse_prediction
|
self.sparse_prediction = sparse_prediction
|
||||||
self.sparse_pred_ignore_index = sparse_pred_ignore_index
|
self.sparse_pred_ignore_index = sparse_pred_ignore_index
|
||||||
self.reference_compile = reference_compile
|
self.reference_compile = reference_compile
|
||||||
|
self.repad_logits_with_grad = repad_logits_with_grad
|
||||||
|
|
||||||
if self.classifier_pooling not in ["cls", "mean"]:
|
if self.classifier_pooling not in ["cls", "mean"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -20,6 +20,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
from contextlib import nullcontext
|
||||||
from typing import Dict, Optional, Tuple, Union
|
from typing import Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -632,12 +633,14 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
|||||||
):
|
):
|
||||||
# If the user didn't specify anything, try to use flash_attention_2 if available.
|
# If the user didn't specify anything, try to use flash_attention_2 if available.
|
||||||
# Otherwise we fall back to the default SDPA -> Eager from the super() method.
|
# Otherwise we fall back to the default SDPA -> Eager from the super() method.
|
||||||
|
# ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
|
||||||
|
# need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
|
||||||
if config._attn_implementation_internal is None:
|
if config._attn_implementation_internal is None:
|
||||||
config._attn_implementation_internal = "flash_attention_2"
|
config._attn_implementation_internal = "flash_attention_2"
|
||||||
try:
|
try:
|
||||||
return cls._check_and_enable_flash_attn_2(
|
return cls._check_and_enable_flash_attn_2(
|
||||||
config,
|
config,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch.float16,
|
||||||
device_map=device_map,
|
device_map=device_map,
|
||||||
hard_check_only=False,
|
hard_check_only=False,
|
||||||
check_device_map=check_device_map,
|
check_device_map=check_device_map,
|
||||||
@@ -647,7 +650,7 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
|||||||
return super()._autoset_attn_implementation(
|
return super()._autoset_attn_implementation(
|
||||||
config,
|
config,
|
||||||
use_flash_attention_2=use_flash_attention_2,
|
use_flash_attention_2=use_flash_attention_2,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch.float16,
|
||||||
device_map=device_map,
|
device_map=device_map,
|
||||||
check_device_map=check_device_map,
|
check_device_map=check_device_map,
|
||||||
)
|
)
|
||||||
@@ -672,6 +675,14 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
|||||||
)
|
)
|
||||||
self.config.reference_compile = False
|
self.config.reference_compile = False
|
||||||
|
|
||||||
|
if self.device.type == "cpu":
|
||||||
|
if self.config.reference_compile:
|
||||||
|
logger.warning_once(
|
||||||
|
"Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. "
|
||||||
|
"Falling back to non-compiled mode."
|
||||||
|
)
|
||||||
|
self.config.reference_compile = False
|
||||||
|
|
||||||
if self.config.reference_compile is None:
|
if self.config.reference_compile is None:
|
||||||
self.config.reference_compile = is_triton_available()
|
self.config.reference_compile = is_triton_available()
|
||||||
|
|
||||||
@@ -763,8 +774,8 @@ def _pad_modernbert_output(
|
|||||||
MODERNBERT_INPUTS_DOCSTRING = r"""
|
MODERNBERT_INPUTS_DOCSTRING = r"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
Indices of input sequence tokens in the vocabulary. With Flash Attention 2.0, padding will be ignored
|
||||||
it.
|
by default should you provide it.
|
||||||
|
|
||||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||||
[`PreTrainedTokenizer.__call__`] for details.
|
[`PreTrainedTokenizer.__call__`] for details.
|
||||||
@@ -790,7 +801,7 @@ MODERNBERT_INPUTS_DOCSTRING = r"""
|
|||||||
sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
|
Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
|
||||||
perform global attention, while the rest perform local attention. This mask is used to avoid attending to
|
perform global attention, while the rest perform local attention. This mask is used to avoid attending to
|
||||||
far-away tokens in the local attention layers.
|
far-away tokens in the local attention layers when not using Flash Attention.
|
||||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||||
config.n_positions - 1]`.
|
config.n_positions - 1]`.
|
||||||
@@ -805,11 +816,11 @@ MODERNBERT_INPUTS_DOCSTRING = r"""
|
|||||||
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
|
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
|
||||||
Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
|
Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
|
||||||
max_seqlen (`int`, *optional*):
|
max_seqlen (`int`, *optional*):
|
||||||
Maximum sequence length in the batch. Used to pad the output tensors.
|
Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
|
||||||
batch_size (`int`, *optional*):
|
batch_size (`int`, *optional*):
|
||||||
Batch size of the input sequences. Used to pad the output tensors.
|
Batch size of the input sequences. Used to pad the output tensors.
|
||||||
seq_len (`int`, *optional*):
|
seq_len (`int`, *optional*):
|
||||||
Sequence length of the input sequences. Used to pad the output tensors.
|
Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
|
||||||
output_attentions (`bool`, *optional*):
|
output_attentions (`bool`, *optional*):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||||
tensors for more detail.
|
tensors for more detail.
|
||||||
@@ -1128,8 +1139,9 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
|
|||||||
loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size)
|
loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size)
|
||||||
|
|
||||||
if self.config._attn_implementation == "flash_attention_2":
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
with torch.no_grad():
|
with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
|
||||||
logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
|
logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,)
|
output = (logits,)
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
from contextlib import nullcontext
|
||||||
from typing import Dict, Literal, Optional, Tuple, Union
|
from typing import Dict, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -141,6 +142,9 @@ class ModernBertConfig(PretrainedConfig):
|
|||||||
the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
|
the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
|
||||||
shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
|
shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
|
||||||
be faster in some scenarios.
|
be faster in some scenarios.
|
||||||
|
repad_logits_with_grad (`bool`, *optional*, defaults to `False`):
|
||||||
|
When True, ModernBertForMaskedLM keeps track of the logits' gradient when repadding for output. This only
|
||||||
|
applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
@@ -196,6 +200,7 @@ class ModernBertConfig(PretrainedConfig):
|
|||||||
sparse_prediction=False,
|
sparse_prediction=False,
|
||||||
sparse_pred_ignore_index=-100,
|
sparse_pred_ignore_index=-100,
|
||||||
reference_compile=None,
|
reference_compile=None,
|
||||||
|
repad_logits_with_grad=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -235,6 +240,7 @@ class ModernBertConfig(PretrainedConfig):
|
|||||||
self.sparse_prediction = sparse_prediction
|
self.sparse_prediction = sparse_prediction
|
||||||
self.sparse_pred_ignore_index = sparse_pred_ignore_index
|
self.sparse_pred_ignore_index = sparse_pred_ignore_index
|
||||||
self.reference_compile = reference_compile
|
self.reference_compile = reference_compile
|
||||||
|
self.repad_logits_with_grad = repad_logits_with_grad
|
||||||
|
|
||||||
if self.classifier_pooling not in ["cls", "mean"]:
|
if self.classifier_pooling not in ["cls", "mean"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -857,12 +863,14 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
|||||||
):
|
):
|
||||||
# If the user didn't specify anything, try to use flash_attention_2 if available.
|
# If the user didn't specify anything, try to use flash_attention_2 if available.
|
||||||
# Otherwise we fall back to the default SDPA -> Eager from the super() method.
|
# Otherwise we fall back to the default SDPA -> Eager from the super() method.
|
||||||
|
# ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
|
||||||
|
# need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
|
||||||
if config._attn_implementation_internal is None:
|
if config._attn_implementation_internal is None:
|
||||||
config._attn_implementation_internal = "flash_attention_2"
|
config._attn_implementation_internal = "flash_attention_2"
|
||||||
try:
|
try:
|
||||||
return cls._check_and_enable_flash_attn_2(
|
return cls._check_and_enable_flash_attn_2(
|
||||||
config,
|
config,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch.float16,
|
||||||
device_map=device_map,
|
device_map=device_map,
|
||||||
hard_check_only=False,
|
hard_check_only=False,
|
||||||
check_device_map=check_device_map,
|
check_device_map=check_device_map,
|
||||||
@@ -872,7 +880,7 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
|||||||
return super()._autoset_attn_implementation(
|
return super()._autoset_attn_implementation(
|
||||||
config,
|
config,
|
||||||
use_flash_attention_2=use_flash_attention_2,
|
use_flash_attention_2=use_flash_attention_2,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch.float16,
|
||||||
device_map=device_map,
|
device_map=device_map,
|
||||||
check_device_map=check_device_map,
|
check_device_map=check_device_map,
|
||||||
)
|
)
|
||||||
@@ -897,6 +905,14 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
|||||||
)
|
)
|
||||||
self.config.reference_compile = False
|
self.config.reference_compile = False
|
||||||
|
|
||||||
|
if self.device.type == "cpu":
|
||||||
|
if self.config.reference_compile:
|
||||||
|
logger.warning_once(
|
||||||
|
"Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. "
|
||||||
|
"Falling back to non-compiled mode."
|
||||||
|
)
|
||||||
|
self.config.reference_compile = False
|
||||||
|
|
||||||
if self.config.reference_compile is None:
|
if self.config.reference_compile is None:
|
||||||
self.config.reference_compile = is_triton_available()
|
self.config.reference_compile = is_triton_available()
|
||||||
|
|
||||||
@@ -916,8 +932,8 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
|||||||
MODERNBERT_INPUTS_DOCSTRING = r"""
|
MODERNBERT_INPUTS_DOCSTRING = r"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
Indices of input sequence tokens in the vocabulary. With Flash Attention 2.0, padding will be ignored
|
||||||
it.
|
by default should you provide it.
|
||||||
|
|
||||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||||
[`PreTrainedTokenizer.__call__`] for details.
|
[`PreTrainedTokenizer.__call__`] for details.
|
||||||
@@ -943,7 +959,7 @@ MODERNBERT_INPUTS_DOCSTRING = r"""
|
|||||||
sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
|
Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
|
||||||
perform global attention, while the rest perform local attention. This mask is used to avoid attending to
|
perform global attention, while the rest perform local attention. This mask is used to avoid attending to
|
||||||
far-away tokens in the local attention layers.
|
far-away tokens in the local attention layers when not using Flash Attention.
|
||||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||||
config.n_positions - 1]`.
|
config.n_positions - 1]`.
|
||||||
@@ -958,11 +974,11 @@ MODERNBERT_INPUTS_DOCSTRING = r"""
|
|||||||
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
|
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
|
||||||
Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
|
Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
|
||||||
max_seqlen (`int`, *optional*):
|
max_seqlen (`int`, *optional*):
|
||||||
Maximum sequence length in the batch. Used to pad the output tensors.
|
Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
|
||||||
batch_size (`int`, *optional*):
|
batch_size (`int`, *optional*):
|
||||||
Batch size of the input sequences. Used to pad the output tensors.
|
Batch size of the input sequences. Used to pad the output tensors.
|
||||||
seq_len (`int`, *optional*):
|
seq_len (`int`, *optional*):
|
||||||
Sequence length of the input sequences. Used to pad the output tensors.
|
Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
|
||||||
output_attentions (`bool`, *optional*):
|
output_attentions (`bool`, *optional*):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||||
tensors for more detail.
|
tensors for more detail.
|
||||||
@@ -1281,8 +1297,9 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
|
|||||||
loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size)
|
loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size)
|
||||||
|
|
||||||
if self.config._attn_implementation == "flash_attention_2":
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
with torch.no_grad():
|
with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
|
||||||
logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
|
logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,)
|
output = (logits,)
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|||||||
Reference in New Issue
Block a user