diff --git a/docs/source/en/model_doc/gpt_neox.md b/docs/source/en/model_doc/gpt_neox.md index fd105a3e82..1319f2e93c 100644 --- a/docs/source/en/model_doc/gpt_neox.md +++ b/docs/source/en/model_doc/gpt_neox.md @@ -95,6 +95,69 @@ Below is an expected speedup diagram that compares pure inference time between t + +## Using Scaled Dot Product Attention (SDPA) +PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function +encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the +[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) +or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention) +page for more information. + +SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set +`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used. + +```python +import torch +from transformers import GPTNeoXForCausalLM +model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", torch_dtype=torch.float16, attn_implementation="sdpa") +... +``` + +For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`). + +On a local benchmark (rtx3080ti-16GB, PyTorch 2.2.1, OS Ubuntu 22.04) using `float16` with +[pythia-410m-deduped](https://huggingface.co/EleutherAI/pythia-410m-deduped), we saw the +following speedups during training and inference. 
+ +### Training +| Batch size | Seq len | Time per batch (Eager - s) | Time per batch (SDPA - s) | Speedup (%) | Eager peak mem (MB) | SDPA peak mem (MB) | Mem saving (%) | +|-----------:|-----------:|---------------------------:|-----------------------------:|------------:|--------------------:|-------------------:|------------------:| +| 1 | 128 | 0.024 | 0.019 | 28.945 | 1789.95 | 1789.95 | 0 | +| 1 | 256 | 0.039 | 0.031 | 23.18 | 1845.83 | 1844.84 | 0.053 | +| 1 | 512 | 0.08 | 0.055 | 45.524 | 2278.38 | 1953.76 | 16.615 | +| 1 | 1024 | 0.19 | 0.102 | 86.777 | 4772.36 | 2408.35 | 98.159 | +| 1 | 2048 | 0.565 | 0.204 | 177.098 | 13484.1 | 3882.01 | 247.348 | +| 2 | 128 | 0.037 | 0.032 | 15.121 | 1843.86 | 1844.78 | -0.05 | +| 2 | 256 | 0.067 | 0.055 | 21.706 | 1999.72 | 1951.67 | 2.462 | +| 2 | 512 | 0.144 | 0.096 | 50.046 | 3613.16 | 2406.77 | 50.125 | +| 2 | 1024 | 0.366 | 0.193 | 89.666 | 8707.55 | 3878.86 | 124.487 | +| 2 | 2048 | OOM | 0.379 | / | OOM | 6825.13 | SDPA does not OOM | +| 4 | 128 | 0.06 | 0.054 | 11.539 | 1947.6 | 1952.06 | -0.228 | +| 4 | 256 | 0.119 | 0.093 | 28.072 | 3008.39 | 2405.99 | 25.038 | +| 4 | 512 | 0.275 | 0.187 | 47.145 | 6290.58 | 3877.29 | 62.242 | +| 4 | 1024 | OOM | 0.36 | / | OOM | 6821.98 | SDPA does not OOM | +| 4 | 2048 | OOM | 0.731 | / | OOM | 12705.1 | SDPA does not OOM | + +### Inference +| Batch size | Seq len | Per token latency Eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem Eager (MB) | Mem SDPA (MB) | Mem saved (%) | +|--------------:|-------------:|--------------------------------:|-------------------------------:|---------------:|------------------:|----------------:|-----------------:| +| 1 | 128 | 6.569 | 5.858 | 12.14 | 974.831 | 974.826 | 0 | +| 1 | 256 | 7.009 | 5.863 | 19.542 | 1029.01 | 1028.08 | 0.09 | +| 1 | 512 | 7.157 | 5.965 | 19.983 | 1137.54 | 1137.52 | 0.001 | +| 1 | 1024 | 7.523 | 6.506 | 15.637 | 1329.3 | 1329.26 | 0.003 | +| 1 | 2048 | 9.271 | 9.205 | 0.713 | 1752.47 | 1734.51 | 
1.036 | +| 2 | 128 | 7.239 | 5.959 | 21.493 | 1044.8 | 1028.37 | 1.597 | +| 2 | 256 | 7.228 | 6.036 | 19.757 | 1167.32 | 1137.73 | 2.601 | +| 2 | 512 | 7.538 | 6.693 | 12.628 | 1352.93 | 1329.55 | 1.758 | +| 2 | 1024 | 8.916 | 8.632 | 3.291 | 1752.56 | 1734.62 | 1.034 | +| 2 | 2048 | 12.628 | 12.606 | 0.181 | 2558.72 | 2545.8 | 0.508 | +| 4 | 128 | 7.278 | 6.046 | 20.373 | 1168.41 | 1137.79 | 2.691 | +| 4 | 256 | 7.614 | 6.588 | 15.574 | 1353.1 | 1329.79 | 1.753 | +| 4 | 512 | 8.798 | 8.144 | 8.028 | 1752.76 | 1734.85 | 1.032 | +| 4 | 1024 | 11.765 | 11.303 | 4.09 | 2558.96 | 2546.04 | 0.508 | +| 4 | 2048 | 19.568 | 17.735 | 10.33 | 4175.5 | 4165.26 | 0.246 | + + ## Resources - [Causal language modeling task guide](../tasks/language_modeling) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index fb443a3ce1..8f48d27e2b 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -203,6 +203,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel) * [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2) * [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel) +* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel) * [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel) * [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel) * [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel) diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index bde881226f..85ee61e7fe 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -18,6 
+18,7 @@ from typing import Optional, Tuple, Union import torch import torch.utils.checkpoint +from packaging import version from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import functional as F @@ -29,6 +30,7 @@ from ...file_utils import ( add_start_docstrings_to_model_forward, replace_return_docstrings, ) +from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -37,7 +39,7 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging +from ...utils import get_torch_version, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging from .configuration_gpt_neox import GPTNeoXConfig @@ -78,6 +80,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel): _no_split_modules = ["GPTNeoXLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_sdpa = True def _init_weights(self, module): """Initialize the weights""" @@ -162,7 +165,56 @@ class GPTNeoXAttention(nn.Module): layer_past: Optional[Tuple[torch.Tensor]] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, - padding_mask: Optional[torch.Tensor] = None, + ): + # Apply attention-specific projections and rope + query, key, value, present = self._attn_projections_and_rope( + hidden_states=hidden_states, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache + ) + + # Compute attention + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + # Reshape outputs + attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size) + attn_output = self.dense(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs 
+= (attn_weights,) + + return outputs + + @classmethod + def _split_heads(cls, tensor, num_attention_heads, attn_head_size): + """ + Splits hidden dim into attn_head_size and num_attention_heads + """ + # tensor: [bs, seq_len, hidden_size] + new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) + # -> [bs, seq_len, num_attention_heads, attn_head_size] + tensor = tensor.view(new_shape) + # -> [bs, num_attention_heads, seq_len, attn_head_size] + tensor = tensor.permute(0, 2, 1, 3) + return tensor + + @classmethod + def _merge_heads(cls, tensor, num_attention_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden dim + """ + # tensor [bs, num_attention_heads, seq_len, attn_head_size] + tensor = tensor.permute(0, 2, 1, 3).contiguous() + # -> [bs, seq_len, num_attention_heads, attn_head_size] + tensor = tensor.view(tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size) + # -> [bs, seq_len, hidden_size] + return tensor + + def _attn_projections_and_rope( + self, + hidden_states: torch.FloatTensor, + position_ids: torch.LongTensor, + layer_past: Optional[Tuple[torch.Tensor]] = None, + use_cache: Optional[bool] = False, ): has_layer_past = layer_past is not None @@ -204,43 +256,7 @@ class GPTNeoXAttention(nn.Module): value = torch.cat((past_value, value), dim=-2) present = (key, value) if use_cache else None - # Compute attention - attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) - - # Reshape outputs - attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size) - attn_output = self.dense(attn_output) - - outputs = (attn_output, present) - if output_attentions: - outputs += (attn_weights,) - - return outputs - - @classmethod - def _split_heads(cls, tensor, num_attention_heads, attn_head_size): - """ - Splits hidden dim into attn_head_size and num_attention_heads - """ - # tensor: [bs, seq_len, hidden_size] - new_shape = tensor.size()[:-1] + 
(num_attention_heads, attn_head_size) - # -> [bs, seq_len, num_attention_heads, attn_head_size] - tensor = tensor.view(new_shape) - # -> [bs, num_attention_heads, seq_len, attn_head_size] - tensor = tensor.permute(0, 2, 1, 3) - return tensor - - @classmethod - def _merge_heads(cls, tensor, num_attention_heads, attn_head_size): - """ - Merges attn_head_size dim and num_attn_heads dim into hidden dim - """ - # tensor [bs, num_attention_heads, seq_len, attn_head_size] - tensor = tensor.permute(0, 2, 1, 3).contiguous() - # -> [bs, seq_len, num_attention_heads, attn_head_size] - tensor = tensor.view(tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size) - # -> [bs, seq_len, hidden_size] - return tensor + return query, key, value, present def _attn(self, query, key, value, attention_mask=None, head_mask=None): # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size] @@ -319,48 +335,13 @@ class GPTNeoXFlashAttention2(GPTNeoXAttention): use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, ): - has_layer_past = layer_past is not None - - # Compute QKV - # Attention heads [batch, seq_len, hidden_size] - # --> [batch, seq_len, (np * 3 * head_size)] - qkv = self.query_key_value(hidden_states) - - # [batch, seq_len, (num_heads * 3 * head_size)] - # --> [batch, seq_len, num_heads, 3 * head_size] - new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size) - qkv = qkv.view(*new_qkv_shape) - - # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size] - query = qkv[..., : self.head_size].permute(0, 2, 1, 3) - key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3) - value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3) + # Apply attention-specific projections and rope + query, key, value, present = self._attn_projections_and_rope( + hidden_states=hidden_states, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache + ) 
query_length = query.shape[-2] - # Compute rotary embeddings on rotary_ndims - query_rot = query[..., : self.rotary_ndims] - query_pass = query[..., self.rotary_ndims :] - key_rot = key[..., : self.rotary_ndims] - key_pass = key[..., self.rotary_ndims :] - - # Compute token offset for rotary embeddings (when decoding) - seq_len = key.shape[-2] - if has_layer_past: - seq_len += layer_past[0].shape[-2] - cos, sin = self.rotary_emb(value, seq_len=seq_len) - query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) - query = torch.cat((query, query_pass), dim=-1) - key = torch.cat((key, key_pass), dim=-1) - - # Cache QKV values - if has_layer_past: - past_key = layer_past[0] - past_value = layer_past[1] - key = torch.cat((past_key, key), dim=-2) - value = torch.cat((past_value, value), dim=-2) - present = (key, value) if use_cache else None - # GPT-neo-X casts query and key in fp32 to apply rotary embedding in full precision target_dtype = value.dtype if query.dtype != target_dtype: @@ -516,6 +497,90 @@ class GPTNeoXFlashAttention2(GPTNeoXAttention): ) +class GPTNeoXSdpaAttention(GPTNeoXAttention): + """ + GPTNeoX attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `GPTNeoXAttention` as the weights of the module stay untouched. The only changes are on the forward pass + to adapt to the SDPA API. + """ + + def __init__(self, config): + super().__init__(config) + + # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom + # attn_mask, so we need to call `.contiguous()`. This was fixed in torch==2.2.0. 
+ # Reference: https://github.com/pytorch/pytorch/issues/112577 + self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + position_ids: torch.LongTensor, + head_mask: Optional[torch.FloatTensor] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ): + if output_attentions or head_mask is not None: + logger.warning_once( + "`GPTNeoXSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but " + "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " + 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + layer_past=layer_past, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + bsz, q_len, _ = hidden_states.size() + + # Apply attention-specific projections and rope + query, key, value, present = self._attn_projections_and_rope( + hidden_states=hidden_states, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache + ) + + # GPT-neo-X casts query and key in fp32 to apply rotary embedding in full precision + target_dtype = value.dtype + if query.dtype != target_dtype: + query = query.to(target_dtype) + if key.dtype != target_dtype: + key = key.to(target_dtype) + + # Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA + if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + # We dispatch to SDPA's 
Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if attention_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attention_mask, + dropout_p=self.attention_dropout.p if self.training else 0.0, + is_causal=is_causal, + ) + + # Reshape outputs + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.dense(attn_output) + + return attn_output, present, None + + def attention_mask_func(attention_scores, ltor_mask): attention_scores.masked_fill_(~ltor_mask, torch.finfo(attention_scores.dtype).min) return attention_scores @@ -660,6 +725,7 @@ class GPTNeoXMLP(nn.Module): GPT_NEOX_ATTENTION_CLASSES = { "eager": GPTNeoXAttention, "flash_attention_2": GPTNeoXFlashAttention2, + "sdpa": GPTNeoXSdpaAttention, } @@ -786,7 +852,8 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): self.emb_dropout = nn.Dropout(config.hidden_dropout) self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)]) self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + self._attn_implementation = config._attn_implementation self.gradient_checkpointing = False @@ -859,27 +926,29 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0) + if inputs_embeds is None: + inputs_embeds = self.embed_in(input_ids) + # Attention mask. 
if attention_mask is not None: assert batch_size > 0, "batch_size has to be defined and > 0" attention_mask = attention_mask.view(batch_size, -1) - if self._use_flash_attention_2: + if self._attn_implementation == "flash_attention_2": attention_mask = attention_mask if 0 in attention_mask else None + elif self._attn_implementation == "sdpa" and not output_attentions and head_mask is None: + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask=attention_mask, + input_shape=(batch_size, seq_length), + inputs_embeds=inputs_embeds, + past_key_values_length=past_length, + ) else: - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. 
- attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask=attention_mask, + input_shape=(batch_size, seq_length), + inputs_embeds=inputs_embeds, + past_key_values_length=past_length, + ) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head @@ -888,9 +957,6 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - if inputs_embeds is None: - inputs_embeds = self.embed_in(input_ids) - hidden_states = self.emb_dropout(inputs_embeds) if self.gradient_checkpointing and self.training: diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py index ed5bcac55e..51a4d235c3 100644 --- a/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -19,7 +19,7 @@ import unittest from parameterized import parameterized from transformers import AutoTokenizer, GPTNeoXConfig, is_torch_available, set_seed -from transformers.testing_utils import require_torch, slow, torch_device +from transformers.testing_utils import require_torch, require_torch_sdpa, slow, torch_device from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -396,6 +396,68 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi torch.testing.assert_close(ntk_sin_long, original_sin_long) self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) + @require_torch_sdpa + @slow + def test_eager_matches_sdpa_generate(self): + """ + Based on tests.models.llama.test_modeling_llama.LlamaModelTest.test_eager_matches_sdpa_generate + which also overwrites the common test as the test 
is flaky on tiny models. + """ + max_new_tokens = 30 + + tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1b") + + model_sdpa = GPTNeoXForCausalLM.from_pretrained( + "EleutherAI/pythia-1b", + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + ).to(torch_device) + + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + + model_eager = GPTNeoXForCausalLM.from_pretrained( + "EleutherAI/pythia-1b", + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + attn_implementation="eager", + ).to(torch_device) + + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + has_sdpa = True + break + if not has_sdpa: + raise ValueError("The SDPA model should have SDPA attention layers") + + texts = [ + "hi here's a longer context, getting longer and", + "Hello this is a very long sentence my friend, very long for real", + "Today I am in Paris and", + ] + + for padding_side in ["left", "right"]: + tokenizer.padding_side = padding_side + tokenizer.pad_token = tokenizer.eos_token + + inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device) + + res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False) + res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False) + + with self.subTest(f"{padding_side}"): + torch.testing.assert_close( + res_eager, + res_sdpa, + msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}", + ) + @require_torch class GPTNeoXLanguageGenerationTest(unittest.TestCase):