[modenbert] fix regression (#39750)

* fix regression

* add FA2 test
This commit is contained in:
Raushan Turganbay
2025-07-29 16:58:59 +02:00
committed by GitHub
parent 379209b603
commit 1ad216bd7d
3 changed files with 40 additions and 20 deletions

View File

@@ -19,6 +19,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import math
from contextlib import nullcontext
from typing import Optional, Union
@@ -459,20 +460,21 @@ class ModernBertAttention(nn.Module):
if layer_id % config.global_attn_every_n_layers != 0:
self.local_attention = (config.local_attention // 2, config.local_attention // 2)
rope_theta = config.local_rope_theta if config.local_rope_theta is not None else config.global_rope_theta
max_position_embeddings = config.local_attention
else:
self.local_attention = (-1, -1)
max_position_embeddings = config.max_position_embeddings
if self.local_attention != (-1, -1):
rope_theta = config.global_rope_theta if config.local_rope_theta is None else config.local_rope_theta
max_position_embeddings = config.local_attention
rope_theta = config.global_rope_theta
if config._attn_implementation == "flash_attention_2":
self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(
dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
)
else:
self.rotary_emb = ModernBertRotaryEmbedding(config=config)
config_copy = copy.deepcopy(config)
config_copy.rope_theta = rope_theta
self.rotary_emb = ModernBertRotaryEmbedding(config=config_copy)
self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
@@ -611,7 +613,9 @@ class ModernBertPreTrainedModel(PreTrainedModel):
if module.bias is not None:
module.bias.data.zero_()
def set_attention_implementation(self, attn_implementation: Union[dict, str]):
def _check_and_adjust_attn_implementation(
self, attn_implementation: Optional[str], is_init_check: bool = False
) -> str:
"""
Checks and dispatches to hhe requested attention implementation.
"""
@@ -620,16 +624,17 @@ class ModernBertPreTrainedModel(PreTrainedModel):
# ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
# need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
requested_attn_implementation = self._check_attn_implementation(attn_implementation)
try:
attn_implementation = (
"flash_attention_2"
if requested_attn_implementation is None and self._flash_attn_2_can_dispatch()
if attn_implementation is None and self._flash_attn_2_can_dispatch()
else attn_implementation
)
except (ValueError, ImportError):
pass
return super().set_attention_implementation(attn_implementation=attn_implementation)
return super()._check_and_adjust_attn_implementation(
attn_implementation=attn_implementation, is_init_check=is_init_check
)
def _maybe_set_compile(self):
if self.config.reference_compile is False:

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import math
from contextlib import nullcontext
from typing import Literal, Optional, Union
@@ -659,20 +660,21 @@ class ModernBertAttention(nn.Module):
if layer_id % config.global_attn_every_n_layers != 0:
self.local_attention = (config.local_attention // 2, config.local_attention // 2)
rope_theta = config.local_rope_theta if config.local_rope_theta is not None else config.global_rope_theta
max_position_embeddings = config.local_attention
else:
self.local_attention = (-1, -1)
max_position_embeddings = config.max_position_embeddings
if self.local_attention != (-1, -1):
rope_theta = config.global_rope_theta if config.local_rope_theta is None else config.local_rope_theta
max_position_embeddings = config.local_attention
rope_theta = config.global_rope_theta
if config._attn_implementation == "flash_attention_2":
self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(
dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
)
else:
self.rotary_emb = ModernBertRotaryEmbedding(config=config)
config_copy = copy.deepcopy(config)
config_copy.rope_theta = rope_theta
self.rotary_emb = ModernBertRotaryEmbedding(config=config_copy)
self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
@@ -811,7 +813,9 @@ class ModernBertPreTrainedModel(PreTrainedModel):
if module.bias is not None:
module.bias.data.zero_()
def set_attention_implementation(self, attn_implementation: Union[dict, str]):
def _check_and_adjust_attn_implementation(
self, attn_implementation: Optional[str], is_init_check: bool = False
) -> str:
"""
Checks and dispatches to hhe requested attention implementation.
"""
@@ -820,16 +824,17 @@ class ModernBertPreTrainedModel(PreTrainedModel):
# ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
# need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
requested_attn_implementation = self._check_attn_implementation(attn_implementation)
try:
attn_implementation = (
"flash_attention_2"
if requested_attn_implementation is None and self._flash_attn_2_can_dispatch()
if attn_implementation is None and self._flash_attn_2_can_dispatch()
else attn_implementation
)
except (ValueError, ImportError):
pass
return super().set_attention_implementation(attn_implementation=attn_implementation)
return super()._check_and_adjust_attn_implementation(
attn_implementation=attn_implementation, is_init_check=is_init_check
)
def _maybe_set_compile(self):
if self.config.reference_compile is False:

View File

@@ -375,6 +375,16 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
config_dict = json.load(f)
self.assertNotIn("reference_compile", config_dict)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
def test_flash_attention_dispatches_by_defaul(self):
"ModernBert should dispatch to FA2 by default, not SDPA"
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config=config)
self.assertTrue(model.config._attn_implementation == "flash_attention_2")
@require_torch
class ModernBertModelIntegrationTest(unittest.TestCase):