diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index a76a6fead7..7b6476bd61 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -19,6 +19,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import math from contextlib import nullcontext from typing import Optional, Union @@ -459,20 +460,21 @@ class ModernBertAttention(nn.Module): if layer_id % config.global_attn_every_n_layers != 0: self.local_attention = (config.local_attention // 2, config.local_attention // 2) + rope_theta = config.local_rope_theta if config.local_rope_theta is not None else config.global_rope_theta + max_position_embeddings = config.local_attention else: self.local_attention = (-1, -1) - - max_position_embeddings = config.max_position_embeddings - if self.local_attention != (-1, -1): - rope_theta = config.global_rope_theta if config.local_rope_theta is None else config.local_rope_theta - max_position_embeddings = config.local_attention + max_position_embeddings = config.max_position_embeddings + rope_theta = config.global_rope_theta if config._attn_implementation == "flash_attention_2": self.rotary_emb = ModernBertUnpaddedRotaryEmbedding( dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta ) else: - self.rotary_emb = ModernBertRotaryEmbedding(config=config) + config_copy = copy.deepcopy(config) + config_copy.rope_theta = rope_theta + self.rotary_emb = ModernBertRotaryEmbedding(config=config_copy) self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity() @@ -611,7 +613,9 @@ class ModernBertPreTrainedModel(PreTrainedModel): if module.bias is not None: module.bias.data.zero_() - def 
set_attention_implementation(self, attn_implementation: Union[dict, str]): + def _check_and_adjust_attn_implementation( + self, attn_implementation: Optional[str], is_init_check: bool = False + ) -> str: """ Checks and dispatches to hhe requested attention implementation. """ @@ -620,16 +624,17 @@ class ModernBertPreTrainedModel(PreTrainedModel): # ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't # need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check. - requested_attn_implementation = self._check_attn_implementation(attn_implementation) try: attn_implementation = ( "flash_attention_2" - if requested_attn_implementation is None and self._flash_attn_2_can_dispatch() + if attn_implementation is None and self._flash_attn_2_can_dispatch() else attn_implementation ) except (ValueError, ImportError): pass - return super().set_attention_implementation(attn_implementation=attn_implementation) + return super()._check_and_adjust_attn_implementation( + attn_implementation=attn_implementation, is_init_check=is_init_check + ) def _maybe_set_compile(self): if self.config.reference_compile is False: diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 254b5d3163..3648e30ac9 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import copy import math from contextlib import nullcontext from typing import Literal, Optional, Union @@ -659,20 +660,21 @@ class ModernBertAttention(nn.Module): if layer_id % config.global_attn_every_n_layers != 0: self.local_attention = (config.local_attention // 2, config.local_attention // 2) + rope_theta = config.local_rope_theta if config.local_rope_theta is not None else config.global_rope_theta + max_position_embeddings = config.local_attention else: self.local_attention = (-1, -1) - - max_position_embeddings = config.max_position_embeddings - if self.local_attention != (-1, -1): - rope_theta = config.global_rope_theta if config.local_rope_theta is None else config.local_rope_theta - max_position_embeddings = config.local_attention + max_position_embeddings = config.max_position_embeddings + rope_theta = config.global_rope_theta if config._attn_implementation == "flash_attention_2": self.rotary_emb = ModernBertUnpaddedRotaryEmbedding( dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta ) else: - self.rotary_emb = ModernBertRotaryEmbedding(config=config) + config_copy = copy.deepcopy(config) + config_copy.rope_theta = rope_theta + self.rotary_emb = ModernBertRotaryEmbedding(config=config_copy) self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity() @@ -811,7 +813,9 @@ class ModernBertPreTrainedModel(PreTrainedModel): if module.bias is not None: module.bias.data.zero_() - def set_attention_implementation(self, attn_implementation: Union[dict, str]): + def _check_and_adjust_attn_implementation( + self, attn_implementation: Optional[str], is_init_check: bool = False + ) -> str: """ Checks and dispatches to hhe requested attention implementation. 
""" @@ -820,16 +824,17 @@ class ModernBertPreTrainedModel(PreTrainedModel): # ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't # need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check. - requested_attn_implementation = self._check_attn_implementation(attn_implementation) try: attn_implementation = ( "flash_attention_2" - if requested_attn_implementation is None and self._flash_attn_2_can_dispatch() + if attn_implementation is None and self._flash_attn_2_can_dispatch() else attn_implementation ) except (ValueError, ImportError): pass - return super().set_attention_implementation(attn_implementation=attn_implementation) + return super()._check_and_adjust_attn_implementation( + attn_implementation=attn_implementation, is_init_check=is_init_check + ) def _maybe_set_compile(self): if self.config.reference_compile is False: diff --git a/tests/models/modernbert/test_modeling_modernbert.py b/tests/models/modernbert/test_modeling_modernbert.py index 187fbcb838..a6bb02e8b8 100644 --- a/tests/models/modernbert/test_modeling_modernbert.py +++ b/tests/models/modernbert/test_modeling_modernbert.py @@ -375,6 +375,16 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa config_dict = json.load(f) self.assertNotIn("reference_compile", config_dict) + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + def test_flash_attention_dispatches_by_default(self): + "ModernBert should dispatch to FA2 by default, not SDPA" + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + model = model_class(config=config) + self.assertTrue(model.config._attn_implementation == "flash_attention_2") + @require_torch class ModernBertModelIntegrationTest(unittest.TestCase):