support SDPA Attention in stablelm (#29106)
* support SDPA Attention in stablelm * add integration test * add fallback for output_attentions * Update src/transformers/models/stablelm/modeling_stablelm.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update tests/models/stablelm/test_modeling_stablelm.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update src/transformers/models/stablelm/modeling_stablelm.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * handle non-contiguous states --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
@@ -177,6 +177,7 @@ For now, Transformers supports SDPA inference and training for the following arc
|
|||||||
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
|
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
|
||||||
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
|
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
|
||||||
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
|
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
|
||||||
|
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
|
||||||
* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)
|
* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)
|
||||||
|
|
||||||
<Tip>
|
<Tip>
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache
|
from ...cache_utils import Cache, DynamicCache
|
||||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
@@ -374,6 +374,102 @@ class StableLmAttention(nn.Module):
|
|||||||
return attn_output, attn_weights, past_key_value
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
class StableLmSdpaAttention(StableLmAttention):
    """
    StableLM attention module that computes attention with
    `torch.nn.functional.scaled_dot_product_attention`. It inherits from `StableLmAttention` and only overrides
    `forward`; projections, rotary embedding and cache handling reuse the attributes of the parent module.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Run self-attention through SDPA.

        Args:
            hidden_states: input of shape `(batch_size, seq_length, hidden_size)`.
            attention_mask: optional mask forwarded unchanged to SDPA as `attn_mask`
                (presumably the 4d mask prepared by the model -- confirm against the caller).
            position_ids: positions handed to `apply_rotary_pos_emb`.
            past_key_value: optional cache, updated in place for `self.layer_idx`.
            output_attentions: when `True`, the call is delegated to the parent (manual) implementation.
            use_cache: only forwarded to the parent implementation on the fallback path.

        Returns:
            `(attn_output, None, past_key_value)`; attention weights are never returned on the SDPA path.

        Raises:
            ValueError: if a cache is supplied but the module was built without a layer index.
        """
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "StableLmModel is using StableLmSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        batch_size, query_len, _ = hidden_states.size()

        def _split_heads(projected: torch.Tensor, n_heads: int) -> torch.Tensor:
            # (batch, seq, n_heads * head_dim) -> (batch, n_heads, seq, head_dim)
            return projected.view(batch_size, query_len, n_heads, self.head_dim).transpose(1, 2)

        query_states = _split_heads(self.q_proj(hidden_states), self.num_heads)
        key_states = _split_heads(self.k_proj(hidden_states), self.num_key_value_heads)
        value_states = _split_heads(self.v_proj(hidden_states), self.num_key_value_heads)

        has_cache = past_key_value is not None

        # The rotary table must cover the cached positions as well as the new ones.
        total_kv_len = key_states.shape[-2]
        if has_cache:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            total_kv_len += past_key_value.get_usable_length(total_kv_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=total_kv_len)

        # Partial rotary embedding: only the first `rotary_dim` channels of each head are rotated.
        rotary_dim = self.rotary_emb.dim
        query_rot, query_pass = query_states[..., :rotary_dim], query_states[..., rotary_dim:]
        key_rot, key_pass = key_states[..., :rotary_dim], key_states[..., rotary_dim:]

        # [batch_size, num_heads, seq_length, rotary_dim]
        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)

        # [batch_size, num_heads, seq_length, head_dim]
        query_states = torch.cat((query_rot, query_pass), dim=-1)
        key_states = torch.cat((key_rot, key_pass), dim=-1)

        if has_cache:
            # Specific to RoPE models with partial rotation
            cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": rotary_dim}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if attention_mask is not None and query_states.device.type == "cuda":
            query_states, key_states, value_states = (
                tensor.contiguous() for tensor in (query_states, key_states, value_states)
            )

        dropout_p = self.attention_dropout.p if self.training else 0.0
        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
        use_causal = self.is_causal and attention_mask is None and query_len > 1

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=dropout_p,
            is_causal=use_causal,
        )

        # (batch, heads, seq, head_dim) -> (batch, seq, hidden_size)
        merged_heads = attn_output.transpose(1, 2).contiguous().view(batch_size, query_len, self.hidden_size)

        return self.o_proj(merged_heads), None, past_key_value
|
||||||
|
|
||||||
|
|
||||||
class StableLmFlashAttention2(StableLmAttention):
|
class StableLmFlashAttention2(StableLmAttention):
|
||||||
"""
|
"""
|
||||||
StableLM flash attention module. This module inherits from `StableLmAttention` as the weights of the module stays
|
StableLM flash attention module. This module inherits from `StableLmAttention` as the weights of the module stays
|
||||||
@@ -574,6 +670,7 @@ class StableLmFlashAttention2(StableLmAttention):
|
|||||||
|
|
||||||
ATTENTION_CLASSES = {
|
ATTENTION_CLASSES = {
|
||||||
"eager": StableLmAttention,
|
"eager": StableLmAttention,
|
||||||
|
"sdpa": StableLmSdpaAttention,
|
||||||
"flash_attention_2": StableLmFlashAttention2,
|
"flash_attention_2": StableLmFlashAttention2,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -680,6 +777,7 @@ class StableLmPreTrainedModel(PreTrainedModel):
|
|||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_cache_class = True
|
_supports_cache_class = True
|
||||||
|
_supports_sdpa = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.initializer_range
|
std = self.config.initializer_range
|
||||||
@@ -858,6 +956,11 @@ class StableLmModel(StableLmPreTrainedModel):
|
|||||||
if self._attn_implementation == "flash_attention_2":
|
if self._attn_implementation == "flash_attention_2":
|
||||||
# 2d mask is passed through the layers
|
# 2d mask is passed through the layers
|
||||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||||
|
# when output_attentions is requested, SDPA falls back to eager attention, so the regular 4d mask below is used
|
||||||
|
elif self._attn_implementation == "sdpa" and not output_attentions:
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||||
|
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# 4d mask is passed through the layers
|
# 4d mask is passed through the layers
|
||||||
attention_mask = _prepare_4d_causal_attention_mask(
|
attention_mask = _prepare_4d_causal_attention_mask(
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from transformers.testing_utils import (
|
|||||||
require_bitsandbytes,
|
require_bitsandbytes,
|
||||||
require_flash_attn,
|
require_flash_attn,
|
||||||
require_torch,
|
require_torch,
|
||||||
|
require_torch_sdpa,
|
||||||
slow,
|
slow,
|
||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
@@ -431,3 +432,65 @@ class StableLmModelIntegrationTest(unittest.TestCase):
|
|||||||
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
|
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
|
||||||
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
|
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
|
||||||
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-3:].tolist())
|
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-3:].tolist())
|
||||||
|
|
||||||
|
# Copied from transformers.tests.models.llama.test_modeling_llama.LlamaModelTest.test_eager_matches_sdpa_generate with Llama->StableLm,saibo/llama-1B->stabilityai/stablelm-3b-4e1t
|
||||||
|
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):
    """
    Overwriting the common test as the test is flaky on tiny models.

    Loads the "stabilityai/stablelm-3b-4e1t" checkpoint twice in float16 (once with the default attention
    implementation, which is asserted to be SDPA, and once with `attn_implementation="eager"`) and checks that
    greedy generation yields the same token ids for a padded batch, with both left and right padding.
    """
    # NOTE(review): the marker above this test declares it `# Copied from` the Llama test suite; if the
    # repository's copy-consistency check covers tests, mirror any comment change there -- confirm.
    max_new_tokens = 30

    tokenizer = AutoTokenizer.from_pretrained("stabilityai/stablelm-3b-4e1t")

    # No `attn_implementation` is passed here: the assertion below checks that SDPA is what gets selected.
    model_sdpa = StableLmForCausalLM.from_pretrained(
        "stabilityai/stablelm-3b-4e1t",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    ).to(torch_device)

    self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")

    model_eager = StableLmForCausalLM.from_pretrained(
        "stabilityai/stablelm-3b-4e1t",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        attn_implementation="eager",
    ).to(torch_device)

    self.assertTrue(model_eager.config._attn_implementation == "eager")

    # Sanity-check the instantiated layers, not just the config flag.
    for name, submodule in model_eager.named_modules():
        if "SdpaAttention" in submodule.__class__.__name__:
            raise ValueError("The eager model should not have SDPA attention layers")

    has_sdpa = False
    for name, submodule in model_sdpa.named_modules():
        if "SdpaAttention" in submodule.__class__.__name__:
            has_sdpa = True
            break
    if not has_sdpa:
        raise ValueError("The SDPA model should have SDPA attention layers")

    # Prompts of different lengths so that the batch actually needs padding.
    texts = [
        "hi here's a longer context, getting longer and",
        "Hello this is a very long sentence my friend, very long for real",
        "Today I am in Paris and",
    ]

    for padding_side in ["left", "right"]:
        tokenizer.padding_side = padding_side
        # The EOS token is reused as the padding token for batching.
        tokenizer.pad_token = tokenizer.eos_token

        inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)

        # `do_sample=False` keeps decoding deterministic so the two outputs are comparable.
        res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)

        with self.subTest(f"{padding_side}"):
            torch.testing.assert_close(
                res_eager,
                res_sdpa,
                msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
            )
|
||||||
|
|||||||
Reference in New Issue
Block a user