diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 36452aabd4..dbc8595e13 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -177,6 +177,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel) * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel) * [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel) +* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel) * [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model) diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 00b02b1431..e7ee3b1462 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -29,7 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache -from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -374,6 +374,102 @@ class StableLmAttention(nn.Module): return attn_output, attn_weights, past_key_value +class StableLmSdpaAttention(StableLmAttention): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "StableLmModel is using StableLmSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + # Partial rotary embedding + query_rot, query_pass = ( + query_states[..., : self.rotary_emb.dim], + query_states[..., self.rotary_emb.dim :], + ) + key_rot, key_pass = ( + key_states[..., : self.rotary_emb.dim], + key_states[..., self.rotary_emb.dim :], + ) + # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) + + # [batch_size, seq_length, num_heads, head_dim] + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) + + if past_key_value is not None: + # Specific to RoPE models with partial rotation + cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # Repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout.p if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + class StableLmFlashAttention2(StableLmAttention): """ StableLM flash attention module. This module inherits from `StableLmAttention` as the weights of the module stays @@ -574,6 +670,7 @@ class StableLmFlashAttention2(StableLmAttention): ATTENTION_CLASSES = { "eager": StableLmAttention, + "sdpa": StableLmSdpaAttention, "flash_attention_2": StableLmFlashAttention2, } @@ -680,6 +777,7 @@ class StableLmPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_cache_class = True + _supports_sdpa = True def _init_weights(self, module): std = self.config.initializer_range @@ -858,6 +956,11 @@ class StableLmModel(StableLmPreTrainedModel): if self._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + # for output_attentions case used fallback to eager attention realization + elif self._attn_implementation == "sdpa" and not output_attentions: + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) else: # 4d mask is passed through the layers attention_mask = _prepare_4d_causal_attention_mask( diff --git a/tests/models/stablelm/test_modeling_stablelm.py b/tests/models/stablelm/test_modeling_stablelm.py index 8ff8eeffc4..2497dfc3ee 100644 --- a/tests/models/stablelm/test_modeling_stablelm.py +++ b/tests/models/stablelm/test_modeling_stablelm.py @@ -24,6 +24,7 @@ from transformers.testing_utils import ( require_bitsandbytes, require_flash_attn, require_torch, + require_torch_sdpa, slow, torch_device, ) @@ -431,3 +432,65 @@ class StableLmModelIntegrationTest(unittest.TestCase): input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device) generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-3:].tolist()) + + # Copied from transformers.tests.models.llama.test_modeling_llama.LlamaModelTest.test_eager_matches_sdpa_generate with Llama->StableLm,saibo/llama-1B->stabilityai/stablelm-3b-4e1t + @require_torch_sdpa + @slow + def test_eager_matches_sdpa_generate(self): + """ + Overwritting the common test as the test is flaky on tiny models + """ + max_new_tokens = 30 + + tokenizer = AutoTokenizer.from_pretrained("stabilityai/stablelm-3b-4e1t") + + model_sdpa = StableLmForCausalLM.from_pretrained( + "stabilityai/stablelm-3b-4e1t", + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + ).to(torch_device) + + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + + model_eager = StableLmForCausalLM.from_pretrained( + "stabilityai/stablelm-3b-4e1t", + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + attn_implementation="eager", + ).to(torch_device) + + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + has_sdpa = True + break + if not has_sdpa: + raise ValueError("The SDPA model should have SDPA attention layers") + + texts = [ + "hi here's a longer context, getting longer and", + "Hello this is a very long sentence my friend, very long for real", + "Today I am in Paris and", + ] + + for padding_side in ["left", "right"]: + tokenizer.padding_side = padding_side + tokenizer.pad_token = tokenizer.eos_token + + inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device) + + res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False) + res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False) + + with self.subTest(f"{padding_side}"): + torch.testing.assert_close( + res_eager, + res_sdpa, + msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}", + )