committed by
GitHub
parent
379209b603
commit
1ad216bd7d
@@ -19,6 +19,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import copy
|
||||||
import math
|
import math
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
@@ -459,20 +460,21 @@ class ModernBertAttention(nn.Module):
|
|||||||
|
|
||||||
if layer_id % config.global_attn_every_n_layers != 0:
|
if layer_id % config.global_attn_every_n_layers != 0:
|
||||||
self.local_attention = (config.local_attention // 2, config.local_attention // 2)
|
self.local_attention = (config.local_attention // 2, config.local_attention // 2)
|
||||||
|
rope_theta = config.local_rope_theta if config.local_rope_theta is not None else config.global_rope_theta
|
||||||
|
max_position_embeddings = config.local_attention
|
||||||
else:
|
else:
|
||||||
self.local_attention = (-1, -1)
|
self.local_attention = (-1, -1)
|
||||||
|
|
||||||
max_position_embeddings = config.max_position_embeddings
|
max_position_embeddings = config.max_position_embeddings
|
||||||
if self.local_attention != (-1, -1):
|
rope_theta = config.global_rope_theta
|
||||||
rope_theta = config.global_rope_theta if config.local_rope_theta is None else config.local_rope_theta
|
|
||||||
max_position_embeddings = config.local_attention
|
|
||||||
|
|
||||||
if config._attn_implementation == "flash_attention_2":
|
if config._attn_implementation == "flash_attention_2":
|
||||||
self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(
|
self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(
|
||||||
dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
|
dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.rotary_emb = ModernBertRotaryEmbedding(config=config)
|
config_copy = copy.deepcopy(config)
|
||||||
|
config_copy.rope_theta = rope_theta
|
||||||
|
self.rotary_emb = ModernBertRotaryEmbedding(config=config_copy)
|
||||||
|
|
||||||
self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
|
self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
|
||||||
self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
|
self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
|
||||||
@@ -611,7 +613,9 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
|||||||
if module.bias is not None:
|
if module.bias is not None:
|
||||||
module.bias.data.zero_()
|
module.bias.data.zero_()
|
||||||
|
|
||||||
def set_attention_implementation(self, attn_implementation: Union[dict, str]):
|
def _check_and_adjust_attn_implementation(
|
||||||
|
self, attn_implementation: Optional[str], is_init_check: bool = False
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Checks and dispatches to the requested attention implementation.
|
Checks and dispatches to the requested attention implementation.
|
||||||
"""
|
"""
|
||||||
@@ -620,16 +624,17 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
|||||||
# ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
|
# ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
|
||||||
# need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
|
# need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
|
||||||
|
|
||||||
requested_attn_implementation = self._check_attn_implementation(attn_implementation)
|
|
||||||
try:
|
try:
|
||||||
attn_implementation = (
|
attn_implementation = (
|
||||||
"flash_attention_2"
|
"flash_attention_2"
|
||||||
if requested_attn_implementation is None and self._flash_attn_2_can_dispatch()
|
if attn_implementation is None and self._flash_attn_2_can_dispatch()
|
||||||
else attn_implementation
|
else attn_implementation
|
||||||
)
|
)
|
||||||
except (ValueError, ImportError):
|
except (ValueError, ImportError):
|
||||||
pass
|
pass
|
||||||
return super().set_attention_implementation(attn_implementation=attn_implementation)
|
return super()._check_and_adjust_attn_implementation(
|
||||||
|
attn_implementation=attn_implementation, is_init_check=is_init_check
|
||||||
|
)
|
||||||
|
|
||||||
def _maybe_set_compile(self):
|
def _maybe_set_compile(self):
|
||||||
if self.config.reference_compile is False:
|
if self.config.reference_compile is False:
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import copy
|
||||||
import math
|
import math
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional, Union
|
||||||
@@ -659,20 +660,21 @@ class ModernBertAttention(nn.Module):
|
|||||||
|
|
||||||
if layer_id % config.global_attn_every_n_layers != 0:
|
if layer_id % config.global_attn_every_n_layers != 0:
|
||||||
self.local_attention = (config.local_attention // 2, config.local_attention // 2)
|
self.local_attention = (config.local_attention // 2, config.local_attention // 2)
|
||||||
|
rope_theta = config.local_rope_theta if config.local_rope_theta is not None else config.global_rope_theta
|
||||||
|
max_position_embeddings = config.local_attention
|
||||||
else:
|
else:
|
||||||
self.local_attention = (-1, -1)
|
self.local_attention = (-1, -1)
|
||||||
|
|
||||||
max_position_embeddings = config.max_position_embeddings
|
max_position_embeddings = config.max_position_embeddings
|
||||||
if self.local_attention != (-1, -1):
|
rope_theta = config.global_rope_theta
|
||||||
rope_theta = config.global_rope_theta if config.local_rope_theta is None else config.local_rope_theta
|
|
||||||
max_position_embeddings = config.local_attention
|
|
||||||
|
|
||||||
if config._attn_implementation == "flash_attention_2":
|
if config._attn_implementation == "flash_attention_2":
|
||||||
self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(
|
self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(
|
||||||
dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
|
dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.rotary_emb = ModernBertRotaryEmbedding(config=config)
|
config_copy = copy.deepcopy(config)
|
||||||
|
config_copy.rope_theta = rope_theta
|
||||||
|
self.rotary_emb = ModernBertRotaryEmbedding(config=config_copy)
|
||||||
|
|
||||||
self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
|
self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
|
||||||
self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
|
self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
|
||||||
@@ -811,7 +813,9 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
|||||||
if module.bias is not None:
|
if module.bias is not None:
|
||||||
module.bias.data.zero_()
|
module.bias.data.zero_()
|
||||||
|
|
||||||
def set_attention_implementation(self, attn_implementation: Union[dict, str]):
|
def _check_and_adjust_attn_implementation(
|
||||||
|
self, attn_implementation: Optional[str], is_init_check: bool = False
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Checks and dispatches to the requested attention implementation.
|
Checks and dispatches to the requested attention implementation.
|
||||||
"""
|
"""
|
||||||
@@ -820,16 +824,17 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
|||||||
# ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
|
# ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
|
||||||
# need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
|
# need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
|
||||||
|
|
||||||
requested_attn_implementation = self._check_attn_implementation(attn_implementation)
|
|
||||||
try:
|
try:
|
||||||
attn_implementation = (
|
attn_implementation = (
|
||||||
"flash_attention_2"
|
"flash_attention_2"
|
||||||
if requested_attn_implementation is None and self._flash_attn_2_can_dispatch()
|
if attn_implementation is None and self._flash_attn_2_can_dispatch()
|
||||||
else attn_implementation
|
else attn_implementation
|
||||||
)
|
)
|
||||||
except (ValueError, ImportError):
|
except (ValueError, ImportError):
|
||||||
pass
|
pass
|
||||||
return super().set_attention_implementation(attn_implementation=attn_implementation)
|
return super()._check_and_adjust_attn_implementation(
|
||||||
|
attn_implementation=attn_implementation, is_init_check=is_init_check
|
||||||
|
)
|
||||||
|
|
||||||
def _maybe_set_compile(self):
|
def _maybe_set_compile(self):
|
||||||
if self.config.reference_compile is False:
|
if self.config.reference_compile is False:
|
||||||
|
|||||||
@@ -375,6 +375,16 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
|||||||
config_dict = json.load(f)
|
config_dict = json.load(f)
|
||||||
self.assertNotIn("reference_compile", config_dict)
|
self.assertNotIn("reference_compile", config_dict)
|
||||||
|
|
||||||
|
@require_flash_attn
|
||||||
|
@require_torch_gpu
|
||||||
|
@pytest.mark.flash_attn_test
|
||||||
|
def test_flash_attention_dispatches_by_defaul(self):
|
||||||
|
"ModernBert should dispatch to FA2 by default, not SDPA"
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config=config)
|
||||||
|
self.assertTrue(model.config._attn_implementation == "flash_attention_2")
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class ModernBertModelIntegrationTest(unittest.TestCase):
|
class ModernBertModelIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user