committed by
GitHub
parent
379209b603
commit
1ad216bd7d
@@ -19,6 +19,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import math
|
||||
from contextlib import nullcontext
|
||||
from typing import Optional, Union
|
||||
@@ -459,20 +460,21 @@ class ModernBertAttention(nn.Module):
|
||||
|
||||
if layer_id % config.global_attn_every_n_layers != 0:
|
||||
self.local_attention = (config.local_attention // 2, config.local_attention // 2)
|
||||
rope_theta = config.local_rope_theta if config.local_rope_theta is not None else config.global_rope_theta
|
||||
max_position_embeddings = config.local_attention
|
||||
else:
|
||||
self.local_attention = (-1, -1)
|
||||
|
||||
max_position_embeddings = config.max_position_embeddings
|
||||
if self.local_attention != (-1, -1):
|
||||
rope_theta = config.global_rope_theta if config.local_rope_theta is None else config.local_rope_theta
|
||||
max_position_embeddings = config.local_attention
|
||||
rope_theta = config.global_rope_theta
|
||||
|
||||
if config._attn_implementation == "flash_attention_2":
|
||||
self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(
|
||||
dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
|
||||
)
|
||||
else:
|
||||
self.rotary_emb = ModernBertRotaryEmbedding(config=config)
|
||||
config_copy = copy.deepcopy(config)
|
||||
config_copy.rope_theta = rope_theta
|
||||
self.rotary_emb = ModernBertRotaryEmbedding(config=config_copy)
|
||||
|
||||
self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
|
||||
self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
|
||||
@@ -611,7 +613,9 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
def set_attention_implementation(self, attn_implementation: Union[dict, str]):
|
||||
def _check_and_adjust_attn_implementation(
|
||||
self, attn_implementation: Optional[str], is_init_check: bool = False
|
||||
) -> str:
|
||||
"""
|
||||
Checks and dispatches to hhe requested attention implementation.
|
||||
"""
|
||||
@@ -620,16 +624,17 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
||||
# ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
|
||||
# need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
|
||||
|
||||
requested_attn_implementation = self._check_attn_implementation(attn_implementation)
|
||||
try:
|
||||
attn_implementation = (
|
||||
"flash_attention_2"
|
||||
if requested_attn_implementation is None and self._flash_attn_2_can_dispatch()
|
||||
if attn_implementation is None and self._flash_attn_2_can_dispatch()
|
||||
else attn_implementation
|
||||
)
|
||||
except (ValueError, ImportError):
|
||||
pass
|
||||
return super().set_attention_implementation(attn_implementation=attn_implementation)
|
||||
return super()._check_and_adjust_attn_implementation(
|
||||
attn_implementation=attn_implementation, is_init_check=is_init_check
|
||||
)
|
||||
|
||||
def _maybe_set_compile(self):
|
||||
if self.config.reference_compile is False:
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import math
|
||||
from contextlib import nullcontext
|
||||
from typing import Literal, Optional, Union
|
||||
@@ -659,20 +660,21 @@ class ModernBertAttention(nn.Module):
|
||||
|
||||
if layer_id % config.global_attn_every_n_layers != 0:
|
||||
self.local_attention = (config.local_attention // 2, config.local_attention // 2)
|
||||
rope_theta = config.local_rope_theta if config.local_rope_theta is not None else config.global_rope_theta
|
||||
max_position_embeddings = config.local_attention
|
||||
else:
|
||||
self.local_attention = (-1, -1)
|
||||
|
||||
max_position_embeddings = config.max_position_embeddings
|
||||
if self.local_attention != (-1, -1):
|
||||
rope_theta = config.global_rope_theta if config.local_rope_theta is None else config.local_rope_theta
|
||||
max_position_embeddings = config.local_attention
|
||||
rope_theta = config.global_rope_theta
|
||||
|
||||
if config._attn_implementation == "flash_attention_2":
|
||||
self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(
|
||||
dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
|
||||
)
|
||||
else:
|
||||
self.rotary_emb = ModernBertRotaryEmbedding(config=config)
|
||||
config_copy = copy.deepcopy(config)
|
||||
config_copy.rope_theta = rope_theta
|
||||
self.rotary_emb = ModernBertRotaryEmbedding(config=config_copy)
|
||||
|
||||
self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
|
||||
self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
|
||||
@@ -811,7 +813,9 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
def set_attention_implementation(self, attn_implementation: Union[dict, str]):
|
||||
def _check_and_adjust_attn_implementation(
|
||||
self, attn_implementation: Optional[str], is_init_check: bool = False
|
||||
) -> str:
|
||||
"""
|
||||
Checks and dispatches to hhe requested attention implementation.
|
||||
"""
|
||||
@@ -820,16 +824,17 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
||||
# ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
|
||||
# need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
|
||||
|
||||
requested_attn_implementation = self._check_attn_implementation(attn_implementation)
|
||||
try:
|
||||
attn_implementation = (
|
||||
"flash_attention_2"
|
||||
if requested_attn_implementation is None and self._flash_attn_2_can_dispatch()
|
||||
if attn_implementation is None and self._flash_attn_2_can_dispatch()
|
||||
else attn_implementation
|
||||
)
|
||||
except (ValueError, ImportError):
|
||||
pass
|
||||
return super().set_attention_implementation(attn_implementation=attn_implementation)
|
||||
return super()._check_and_adjust_attn_implementation(
|
||||
attn_implementation=attn_implementation, is_init_check=is_init_check
|
||||
)
|
||||
|
||||
def _maybe_set_compile(self):
|
||||
if self.config.reference_compile is False:
|
||||
|
||||
@@ -375,6 +375,16 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
config_dict = json.load(f)
|
||||
self.assertNotIn("reference_compile", config_dict)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@pytest.mark.flash_attn_test
|
||||
def test_flash_attention_dispatches_by_defaul(self):
|
||||
"ModernBert should dispatch to FA2 by default, not SDPA"
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=config)
|
||||
self.assertTrue(model.config._attn_implementation == "flash_attention_2")
|
||||
|
||||
|
||||
@require_torch
|
||||
class ModernBertModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user