From a0779b9e19093dc0371abbf516030491eec3d86c Mon Sep 17 00:00:00 2001 From: Poedator <24738311+poedator@users.noreply.github.com> Date: Mon, 13 May 2024 13:46:06 +0200 Subject: [PATCH] Llama: fix custom 4D masks, v2 (#30348) * 4d mask fixes * Update custom 4D mask logic * test moved to mixin * extra tests 4d mask * upd 4d mask and StaticCache handling * added Mask4DTestHard to mistral tests * post-rebase fixes * test fixes for StaticCache * make fix-copies * upd 1 after #30476 * fix common tests * rm elif attention_mask.dim() == 4: * tests combined, fixed, mixtral supported * bigbird style chg reverted * rm if attention_mask.dim() == 2 * modeling_llama formatting chg --------- Co-authored-by: Joao Gante --- src/transformers/modeling_attn_mask_utils.py | 26 +- .../modeling_bigbird_pegasus.py | 1 - .../models/cohere/modeling_cohere.py | 40 +-- src/transformers/models/dbrx/modeling_dbrx.py | 40 +-- .../models/gemma/modeling_gemma.py | 40 +-- .../models/llama/modeling_llama.py | 40 +-- src/transformers/models/olmo/modeling_olmo.py | 40 +-- tests/models/llama/test_modeling_llama.py | 257 +++++++++++++++++- tests/models/mistral/test_modeling_mistral.py | 124 +++++++++ tests/test_modeling_common.py | 74 +++++ tests/test_modeling_utils.py | 225 --------------- 11 files changed, 541 insertions(+), 366 deletions(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 8dcf40268d..fb85d018c9 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -250,7 +250,7 @@ class AttentionMaskConverter: allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). """ - batch_size, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1] + _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1] key_value_length = query_length + past_key_values_length is_tracing = ( @@ -275,11 +275,7 @@ class AttentionMaskConverter: ignore_causal_mask = True elif sliding_window is None or key_value_length < sliding_window: if len(attention_mask.shape) == 4: - expected_shape = (batch_size, 1, query_length, key_value_length) - if tuple(attention_mask.shape) != expected_shape: - raise ValueError( - f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}." - ) + return False elif (is_training or not is_tracing) and torch.all(attention_mask == 1): if query_length == 1 or key_value_length == query_length: # For query_length == 1, causal attention and bi-directional attention are the same. @@ -387,12 +383,18 @@ def _prepare_4d_causal_attention_mask_for_sdpa( input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device ) else: - expanded_4d_mask = attn_mask_converter.to_4d( - attention_mask, - input_shape[-1], - dtype=inputs_embeds.dtype, - key_value_length=key_value_length, - ) + if attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + expanded_4d_mask = attention_mask + else: + expanded_4d_mask = attn_mask_converter.to_4d( + attention_mask, + input_shape[-1], + dtype=inputs_embeds.dtype, + key_value_length=key_value_length, + ) # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 74ec4432a5..b4e6419f99 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -14,7 +14,6 @@ # limitations under the License. """ PyTorch BigBirdPegasus model.""" - import copy import math from typing import List, Optional, Tuple, Union diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index d96131d770..b25528dfe7 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -995,37 +995,27 @@ class CohereModel(CoherePreTrainedModel): else past_seen_tokens + sequence_length + 1 ) - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.dim() == 2: + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) - elif attention_mask.dim() == 4: - # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with - # cache. In that case, the 4D attention mask attends to the newest tokens only. - if attention_mask.shape[-2] < cache_position[0] + sequence_length: - logger.warning_once( - "Passing a 4d mask shorter than the input length is deprecated and will be removed in " - "transformers v4.42.0" - ) - offset = cache_position[0] - else: - offset = 0 - mask_shape = attention_mask.shape - mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[ - : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] - ] = mask_slice - if ( self.config._attn_implementation == "sdpa" and attention_mask is not None diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 2e185aa885..38c1fc814b 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -1241,37 +1241,27 @@ class DbrxModel(DbrxPreTrainedModel): else past_seen_tokens + sequence_length + 1 ) - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.dim() == 2: + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) - elif attention_mask.dim() == 4: - # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with - # cache. In that case, the 4D attention mask attends to the newest tokens only. - if attention_mask.shape[-2] < cache_position[0] + sequence_length: - logger.warning_once( - "Passing a 4d mask shorter than the input length is deprecated and will be removed in " - "transformers v4.42.0" - ) - offset = cache_position[0] - else: - offset = 0 - mask_shape = attention_mask.shape - mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[ - : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] - ] = mask_slice - if ( self.config._attn_implementation == "sdpa" and attention_mask is not None diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 8f78937047..12d01a6ea0 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -986,37 +986,27 @@ class GemmaModel(GemmaPreTrainedModel): else past_seen_tokens + sequence_length + 1 ) - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.dim() == 2: + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) - elif attention_mask.dim() == 4: - # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with - # cache. In that case, the 4D attention mask attends to the newest tokens only. - if attention_mask.shape[-2] < cache_position[0] + sequence_length: - logger.warning_once( - "Passing a 4d mask shorter than the input length is deprecated and will be removed in " - "transformers v4.42.0" - ) - offset = cache_position[0] - else: - offset = 0 - mask_shape = attention_mask.shape - mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[ - : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] - ] = mask_slice - if ( self.config._attn_implementation == "sdpa" and attention_mask is not None diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index d840b03faf..c6da59fcfb 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1073,37 +1073,27 @@ class LlamaModel(LlamaPreTrainedModel): else past_seen_tokens + sequence_length + 1 ) - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.dim() == 2: + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) - elif attention_mask.dim() == 4: - # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with - # cache. In that case, the 4D attention mask attends to the newest tokens only. - if attention_mask.shape[-2] < cache_position[0] + sequence_length: - logger.warning_once( - "Passing a 4d mask shorter than the input length is deprecated and will be removed in " - "transformers v4.42.0" - ) - offset = cache_position[0] - else: - offset = 0 - mask_shape = attention_mask.shape - mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[ - : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] - ] = mask_slice - if ( self.config._attn_implementation == "sdpa" and attention_mask is not None diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 5009ac84be..6a7b2f748f 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -1052,37 +1052,27 @@ class OlmoModel(OlmoPreTrainedModel): else past_seen_tokens + sequence_length + 1 ) - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.dim() == 2: + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) - elif attention_mask.dim() == 4: - # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with - # cache. In that case, the 4D attention mask attends to the newest tokens only. - if attention_mask.shape[-2] < cache_position[0] + sequence_length: - logger.warning_once( - "Passing a 4d mask shorter than the input length is deprecated and will be removed in " - "transformers v4.42.0" - ) - offset = cache_position[0] - else: - offset = 0 - mask_shape = attention_mask.shape - mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[ - : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] - ] = mask_slice - if ( self.config._attn_implementation == "sdpa" and attention_mask is not None diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index e63e537974..5d402bd859 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -12,8 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Testing suite for the PyTorch LLaMA model. """ +"""Testing suite for the PyTorch LLaMA model.""" +import gc import tempfile import unittest @@ -21,7 +22,7 @@ import pytest from packaging import version from parameterized import parameterized -from transformers import LlamaConfig, is_torch_available, set_seed +from transformers import LlamaConfig, StaticCache, is_torch_available, set_seed from transformers.testing_utils import ( require_bitsandbytes, require_flash_attn, @@ -804,7 +805,7 @@ end '
 \ndef main():\n    factory = InterfaceManagerFactory(start=datetime.now())\n    managers = []\n    for i in range(10):\n        managers.append(factory.build(id=i))\n  class InterfaceManagerFactory(AbstractManagerFactory):\n    def __init__(',
             '
  = 0 :=\nbegin\nsplit,\n{ intros h f,\n    rw pi_1_etalisation at h,\n    simp [h],\n    refl\n},\n{ intro h,\n    have := @quasi_adjoint C D P,\n    simp [←pi_1_etalisation, this, h],\n    refl\n}\nend\n  /-- A quasi-prefunctoid is 1-connected iff all its etalisations are 1-connected. -/\ntheorem connected_iff_etalisation [C D : precategoroid] (P : quasi_prefunctoid C D) :\nπ₁ P = 0 ↔ '
         ]
-        EXPECTED_IDS = torch.tensor([[    1, 32007, 822, 3349, 29918, 5464, 29918, 294, 18869, 29898,29879, 29901, 851, 29897, 1599, 851, 29901, 13, 1678, 9995, 29871, 32008, 13, 1678, 736, 1121, 13, 32009, 15941, 1661, 29899, 28599, 2687, 4890, 515, 263, 1347, 29889, 13, 13, 1678, 826, 3174, 29901, 13, 4706, 269, 29901, 450, 1347, 304, 3349, 1661, 29899, 28599, 2687, 4890, 515, 29889, 13, 13, 1678, 16969, 29901, 13, 4706, 450, 1347, 411, 1661, 29899, 28599, 2687, 4890, 6206, 29889, 13, 1678, 9995, 13, 1678, 1121, 353, 5124, 13, 1678, 363, 274, 297, 269, 29901, 13, 4706, 565, 4356, 29898, 29883, 29897, 529, 29871, 29896, 29906, 29947, 29901, 13, 9651, 1121, 4619, 274, 32010, 2]])
+        EXPECTED_IDS = torch.tensor([[1, 32007, 822, 3349, 29918, 5464, 29918, 294, 18869, 29898, 29879, 29901, 851, 29897, 1599, 851, 29901, 13, 1678, 9995, 29871, 32008, 13, 1678, 736, 1121, 13, 32009, 15941, 1661, 29899, 28599, 2687, 4890, 515, 263, 1347, 29889, 13, 13, 1678, 826, 3174, 29901, 13, 4706, 269, 29901, 450, 1347, 304, 3349, 1661, 29899, 28599, 2687, 4890, 515, 29889, 13, 13, 1678, 16969, 29901, 13, 4706, 450, 1347, 411, 1661, 29899, 28599, 2687, 4890, 6206, 29889, 13, 1678, 9995, 13, 1678, 1121, 353, 5124, 13, 1678, 363, 274, 297, 269, 29901, 13, 4706, 565, 4356, 29898, 29883, 29897, 529, 29871, 29896, 29906, 29947, 29901, 13, 9651, 1121, 4619, 274, 32010, 2]])
         # fmt: on
         self.assertEqual(processed_text_suffix_first, EXPECTED_TEXT)
         input_ids = tokenizer(self.PROMPTS[0], return_tensors="pt")["input_ids"]
@@ -816,3 +817,253 @@ end
         ]
         infilling = tokenizer.batch_decode(generated_ids)
         self.assertEqual(infilling, EXPECTED_INFILLING)
+
+
+@slow
+@require_torch_gpu
+class Mask4DTestHard(unittest.TestCase):
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def setUp(self):
+        model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+        self.model_dtype = torch.float32
+        self.tokenizer = LlamaTokenizer.from_pretrained(model_name)
+        self.model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
+
+    def get_test_data(self):
+        template = "my favorite {}"
+        items = ("pet is a", "artist plays a", "name is L")  # same number of tokens in each item
+
+        batch_separate = [template.format(x) for x in items]  # 3 separate lines
+        batch_shared_prefix = template.format(" ".join(items))  # 1 line with options concatenated
+
+        input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device)
+        input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device)
+
+        mask_shared_prefix = torch.tensor(
+            [
+                [
+                    [
+                        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1],
+                    ]
+                ]
+            ],
+            device=torch_device,
+        )
+
+        position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device)
+
+        # building custom positions ids based on custom mask
+        position_ids_shared_prefix = (mask_shared_prefix.sum(dim=-1) - 1).reshape(1, -1)
+        # effectively: position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device)
+
+        # inverting the mask
+        min_dtype = torch.finfo(self.model_dtype).min
+        mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=self.model_dtype) * min_dtype
+
+        return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
+
+    def test_stacked_causal_mask(self):
+        (
+            input_ids,
+            position_ids,
+            input_ids_shared_prefix,
+            mask_shared_prefix,
+            position_ids_shared_prefix,
+        ) = self.get_test_data()
+
+        # regular batch
+        logits = self.model.forward(input_ids, position_ids=position_ids).logits
+        logits_last = logits[:, -1, :]  # last tokens in each batch line
+        decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
+
+        # single forward run with 4D custom mask
+        logits_shared_prefix = self.model.forward(
+            input_ids_shared_prefix, attention_mask=mask_shared_prefix, position_ids=position_ids_shared_prefix
+        ).logits
+        logits_shared_prefix_last = logits_shared_prefix[
+            0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], :
+        ]  # last three tokens
+        decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)]
+
+        self.assertEqual(decoded, decoded_shared_prefix)
+
+    def test_partial_stacked_causal_mask(self):
+        # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks
+
+        (
+            input_ids,
+            position_ids,
+            input_ids_shared_prefix,
+            mask_shared_prefix,
+            position_ids_shared_prefix,
+        ) = self.get_test_data()
+
+        # regular batch
+        logits = self.model.forward(input_ids, position_ids=position_ids).logits
+        logits_last = logits[:, -1, :]  # last tokens in each batch line
+        decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
+
+        # 2 forward runs with custom 4D masks
+        part_a = 3  # split point
+
+        input_1a = input_ids_shared_prefix[:, :part_a]
+        position_ids_1a = position_ids_shared_prefix[:, :part_a]
+        mask_1a = mask_shared_prefix[:, :, :part_a, :part_a]
+
+        outs_1a = self.model.forward(input_1a, attention_mask=mask_1a, position_ids=position_ids_1a)
+        past_key_values_a = outs_1a["past_key_values"]
+
+        # Case 1: we pass a 4D attention mask regarding the current sequence length (i.e. [..., seq_len, full_len])
+        input_1b = input_ids_shared_prefix[:, part_a:]
+        position_ids_1b = position_ids_shared_prefix[:, part_a:]
+        mask_1b = mask_shared_prefix[:, :, part_a:, :]
+        outs_1b = self.model.forward(
+            input_1b,
+            attention_mask=mask_1b,
+            position_ids=position_ids_1b,
+            past_key_values=past_key_values_a,
+        )
+        decoded_1b = [
+            self.tokenizer.decode(t)
+            for t in outs_1b.logits.argmax(-1)[
+                0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a
+            ]
+        ]
+        self.assertEqual(decoded, decoded_1b)
+
+    def test_stacked_causal_mask_static_cache(self):
+        """same as above but with StaticCache"""
+        (
+            input_ids,
+            position_ids,
+            input_ids_shared_prefix,
+            mask_shared_prefix,
+            position_ids_shared_prefix,
+        ) = self.get_test_data()
+
+        # regular batch
+        logits = self.model.forward(input_ids, position_ids=position_ids).logits
+        logits_last = logits[:, -1, :]  # last tokens in each batch line
+        decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
+
+        # upgrade the model with StaticCache
+        max_cache_len = 16  # note that max_cache_len is greater than the attention_mask.shape[-1]
+        past_key_values = StaticCache(
+            config=self.model.config,
+            max_batch_size=1,
+            max_cache_len=max_cache_len,
+            device=torch_device,
+            dtype=self.model.dtype,
+        )
+
+        padded_attention_mask = torch.nn.functional.pad(
+            input=mask_shared_prefix,
+            pad=(0, max_cache_len - mask_shared_prefix.shape[-1]),
+            mode="constant",
+            value=torch.finfo(self.model_dtype).min,
+        )
+
+        # single forward run with 4D custom mask
+        logits_shared_prefix = self.model.forward(
+            input_ids_shared_prefix,
+            attention_mask=padded_attention_mask,
+            position_ids=position_ids_shared_prefix,
+            cache_position=torch.arange(input_ids_shared_prefix.shape[-1], device=torch_device),
+            past_key_values=past_key_values,
+        ).logits
+        logits_shared_prefix_last = logits_shared_prefix[
+            0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], :
+        ]  # last three tokens
+        decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)]
+
+        self.assertEqual(decoded, decoded_shared_prefix)
+
+    def test_partial_stacked_causal_mask_static_cache(self):
+        # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks
+        # we pass a 4D attention mask shaped [..., seq_len, full_static_cache_len])
+        (
+            input_ids,
+            position_ids,
+            input_ids_shared_prefix,
+            mask_shared_prefix,
+            position_ids_shared_prefix,
+        ) = self.get_test_data()
+
+        # regular batch
+        logits = self.model.forward(input_ids, position_ids=position_ids).logits
+        logits_last = logits[:, -1, :]  # last tokens in each batch line
+        decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
+
+        # upgrade the model with StaticCache
+        max_cache_len = 16  # note that max_cache_len is greater than the attention_mask.shape[-1]
+        past_key_values = StaticCache(
+            config=self.model.config,
+            max_batch_size=1,
+            max_cache_len=max_cache_len,
+            device=torch_device,
+            dtype=self.model.dtype,
+        )
+
+        # forward run for the first part of input
+        part_a = 3  # split point
+
+        input_1a = input_ids_shared_prefix[:, :part_a]
+        position_ids_1a = position_ids_shared_prefix[:, :part_a]
+        mask_1a = mask_shared_prefix[:, :, :part_a, :part_a]
+
+        padded_mask_1a = torch.nn.functional.pad(
+            input=mask_1a,
+            pad=(0, max_cache_len - mask_1a.shape[-1]),
+            mode="constant",
+            value=torch.finfo(self.model_dtype).min,
+        )
+
+        _ = self.model.forward(
+            input_1a,
+            attention_mask=padded_mask_1a,
+            position_ids=position_ids_1a,
+            cache_position=torch.arange(part_a, device=torch_device),
+            past_key_values=past_key_values,
+        )
+
+        # forward run for the second part of input
+        input_1b = input_ids_shared_prefix[:, part_a:]
+        position_ids_1b = position_ids_shared_prefix[:, part_a:]
+        mask_1b = mask_shared_prefix[:, :, part_a:, :]
+
+        padded_mask_1b = torch.nn.functional.pad(
+            input=mask_1b, pad=(0, max_cache_len - mask_1b.shape[-1]), mode="constant", value=0
+        )
+
+        outs_1b = self.model.forward(
+            input_1b,
+            attention_mask=padded_mask_1b,
+            position_ids=position_ids_1b,
+            cache_position=torch.arange(
+                part_a,
+                input_ids_shared_prefix.shape[-1],
+                device=torch_device,
+            ),
+            past_key_values=past_key_values,
+        )
+        decoded_1b = [
+            self.tokenizer.decode(t)
+            for t in outs_1b.logits.argmax(-1)[
+                0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a
+            ]
+        ]
+        self.assertEqual(decoded, decoded_1b)
diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py
index 3500024b3e..bbc36c050e 100644
--- a/tests/models/mistral/test_modeling_mistral.py
+++ b/tests/models/mistral/test_modeling_mistral.py
@@ -627,3 +627,127 @@ class MistralIntegrationTest(unittest.TestCase):
         del model
         backend_empty_cache(torch_device)
         gc.collect()
+
+
+@slow
+@require_torch_gpu
+class Mask4DTestHard(unittest.TestCase):
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def setUp(self):
+        model_name = "mistralai/Mistral-7B-v0.1"
+        self.model_dtype = torch.float32
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
+        self.model = MistralForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
+
+    def get_test_data(self):
+        template = "my favorite {}"
+        items = ("pet is a", "artist plays a", "name is L")  # same number of tokens in each item
+
+        batch_separate = [template.format(x) for x in items]  # 3 separate lines
+        batch_shared_prefix = template.format(" ".join(items))  # 1 line with options concatenated
+
+        input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device)
+        input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device)
+
+        mask_shared_prefix = torch.tensor(
+            [
+                [
+                    [
+                        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1],
+                    ]
+                ]
+            ],
+            device=torch_device,
+        )
+
+        position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device)
+
+        # building custom positions ids based on custom mask
+        position_ids_shared_prefix = (mask_shared_prefix.sum(dim=-1) - 1).reshape(1, -1)
+        # effectively: position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device)
+
+        # inverting the mask
+        min_dtype = torch.finfo(self.model_dtype).min
+        mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=self.model_dtype) * min_dtype
+
+        return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
+
+    def test_stacked_causal_mask(self):
+        (
+            input_ids,
+            position_ids,
+            input_ids_shared_prefix,
+            mask_shared_prefix,
+            position_ids_shared_prefix,
+        ) = self.get_test_data()
+
+        # regular batch
+        logits = self.model.forward(input_ids, position_ids=position_ids).logits
+        logits_last = logits[:, -1, :]  # last tokens in each batch line
+        decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
+
+        # single forward run with 4D custom mask
+        logits_shared_prefix = self.model.forward(
+            input_ids_shared_prefix, attention_mask=mask_shared_prefix, position_ids=position_ids_shared_prefix
+        ).logits
+        logits_shared_prefix_last = logits_shared_prefix[
+            0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], :
+        ]  # last three tokens
+        decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)]
+
+        self.assertEqual(decoded, decoded_shared_prefix)
+
+    def test_partial_stacked_causal_mask(self):
+        # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks
+
+        (
+            input_ids,
+            position_ids,
+            input_ids_shared_prefix,
+            mask_shared_prefix,
+            position_ids_shared_prefix,
+        ) = self.get_test_data()
+
+        # regular batch
+        logits = self.model.forward(input_ids, position_ids=position_ids).logits
+        logits_last = logits[:, -1, :]  # last tokens in each batch line
+        decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
+
+        # 2 forward runs with custom 4D masks
+        part_a = 3  # split point
+
+        input_1a = input_ids_shared_prefix[:, :part_a]
+        position_ids_1a = position_ids_shared_prefix[:, :part_a]
+        mask_1a = mask_shared_prefix[:, :, :part_a, :part_a]
+
+        outs_1a = self.model.forward(input_1a, attention_mask=mask_1a, position_ids=position_ids_1a)
+        past_key_values_a = outs_1a["past_key_values"]
+
+        # Case 1: we pass a 4D attention mask regarding the current sequence length (i.e. [..., seq_len, full_len])
+        input_1b = input_ids_shared_prefix[:, part_a:]
+        position_ids_1b = position_ids_shared_prefix[:, part_a:]
+        mask_1b = mask_shared_prefix[:, :, part_a:, :]
+        outs_1b = self.model.forward(
+            input_1b, attention_mask=mask_1b, position_ids=position_ids_1b, past_key_values=past_key_values_a
+        )
+        decoded_1b = [
+            self.tokenizer.decode(t)
+            for t in outs_1b.logits.argmax(-1)[
+                0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a
+            ]
+        ]
+        self.assertEqual(decoded, decoded_1b)
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index df585f4afc..daa438e9f1 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -4277,6 +4277,80 @@ class ModelTesterMixin:
 
                 self.assertFalse(fa2_correctly_converted)
 
+    def _get_custom_4d_mask_test_data(self):
+        # Sequence in which all but the last token is the same
+        input_ids = torch.tensor(
+            [[10, 11, 12, 13], [10, 11, 12, 14], [10, 11, 12, 15]], device=torch_device, dtype=torch.int64
+        )
+        position_ids = torch.tensor([[0, 1, 2, 3]] * 3, device=torch_device, dtype=torch.int64)
+
+        # Combining common prefix with the unique ending tokens:
+        input_ids_shared_prefix = torch.cat([input_ids[0][:-1], input_ids[:, -1]]).unsqueeze(0)
+
+        # Creating a 4D mask where each of the last 3 tokens do not attend to each other.
+        mask_shared_prefix = torch.tensor(
+            [
+                [
+                    [
+                        [1, 0, 0, 0, 0, 0],
+                        [1, 1, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0],
+                        [1, 1, 1, 1, 0, 0],
+                        [1, 1, 1, 0, 1, 0],
+                        [1, 1, 1, 0, 0, 1],
+                    ]
+                ]
+            ],
+        )
+        # inverting the attention mask
+        mask_dtype = torch.float32
+        min_dtype = torch.finfo(mask_dtype).min
+        mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=mask_dtype, device=torch_device) * min_dtype
+
+        # Creating a position_ids tensor. note the repeating figures in the end.
+        position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 3, 3]], device=torch_device, dtype=torch.int64)
+
+        return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
+
+    def test_custom_4d_attention_mask(self):
+        if len(self.all_generative_model_classes) == 0:
+            self.skipTest("Model architecture has no generative classes, and thus not necessarily supporting 4D masks")
+
+        for model_class in self.all_generative_model_classes:
+            if not model_class._supports_cache_class:
+                self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks")
+            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+            model = model_class(config).to(device=torch_device, dtype=torch.float32)
+
+            (
+                input_ids,
+                position_ids,
+                input_ids_shared_prefix,
+                mask_shared_prefix,
+                position_ids_shared_prefix,
+            ) = self._get_custom_4d_mask_test_data()
+
+            logits = model.forward(input_ids, position_ids=position_ids).logits
+            # logits.shape == torch.Size([3, 4, ...])
+
+            logits_shared_prefix = model(
+                input_ids_shared_prefix,
+                attention_mask=mask_shared_prefix,
+                position_ids=position_ids_shared_prefix,
+            )[0]
+            # logits_shared_prefix.shape == torch.Size([1, 6, ...])
+
+            out_last_tokens = logits[:, -1, :]  # last tokens in each batch line
+            out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :]  # last three tokens
+
+            # comparing greedily-chosen tokens:
+            assert torch.equal(out_last_tokens.max(axis=1).indices, out_shared_prefix_last_tokens.max(axis=1).indices)
+
+            # comparing softmax-normalized logits:
+            normalized_0 = F.softmax(out_last_tokens)
+            normalized_1 = F.softmax(out_shared_prefix_last_tokens)
+            torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
+
 
 global_rng = random.Random()
 
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index f98e1a2a23..9a00340d14 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
-import gc
 import glob
 import json
 import os
@@ -53,7 +52,6 @@ from transformers.testing_utils import (
     require_tf,
     require_torch,
     require_torch_accelerator,
-    require_torch_gpu,
     require_torch_multi_accelerator,
     require_usr_bin_time,
     slow,
@@ -2107,229 +2105,6 @@ class TestAttentionImplementation(unittest.TestCase):
         self.assertTrue("PyTorch SDPA requirements in Transformers are not met" in str(cm.exception))
 
 
-@require_torch_gpu
-class Mask4DTestBase(unittest.TestCase):
-    def tearDown(self):
-        gc.collect()
-        torch.cuda.empty_cache()
-
-    def get_test_data(self):
-        texts = ["the cat sat", "the cat had", "the cat is"]
-        encoded = [self.tokenizer.encode(t) for t in texts]
-        input_0 = torch.tensor(encoded, device=torch_device)
-        # tensor([[   1,  278, 6635, 3290],
-        # [   1,  278, 6635,  750],
-        # [   1,  278, 6635,  338]], device='cuda:0')
-
-        position_ids_0 = torch.tensor([[0, 1, 2, 3]] * 3, device=torch_device, dtype=torch.int64)
-
-        # Combining common prefix with the unique ending tokens:
-        input_1 = torch.cat([input_0[0][:-1], input_0[:, -1]]).unsqueeze(0)
-        # tensor([[   1,  278, 6635, 3290,  750,  338]], device='cuda:0')
-
-        # Creating a 4D mask where each of the last 3 tokens do not attend to each other.
-        mask_1 = torch.tensor(
-            [
-                [
-                    [
-                        [1, 0, 0, 0, 0, 0],
-                        [1, 1, 0, 0, 0, 0],
-                        [1, 1, 1, 0, 0, 0],
-                        [1, 1, 1, 1, 0, 0],
-                        [1, 1, 1, 0, 1, 0],
-                        [1, 1, 1, 0, 0, 1],
-                    ]
-                ]
-            ],
-            device="cuda:0",
-            dtype=torch.int64,
-        )
-
-        # Creating a position_ids tensor. note the repeating figures in the end.
-        position_ids_1 = torch.tensor([[0, 1, 2, 3, 3, 3]], device=torch_device, dtype=torch.int64)
-
-        return input_0, position_ids_0, input_1, mask_1, position_ids_1
-
-
-@require_torch_gpu
-class Mask4DTestFP32(Mask4DTestBase):
-    def setUp(self):
-        model_name = "JackFram/llama-68m"  # small Llama-like model from FlexFlow
-        self.model_dtype = torch.float32
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
-
-    def test_attention(self):
-        """comparing outputs of attention layer"""
-        # Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
-        input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()
-        causal_mask_1 = (1 - mask_1).to(self.model_dtype) * torch.finfo(self.model_dtype).min
-
-        hid_0 = self.model.model.embed_tokens(input_0)
-        outs_0 = self.model.model.layers[0].self_attn.forward(hid_0, position_ids=position_ids_0)[0]
-        # outs_0.shape == torch.Size([3, 4, 768])
-
-        hid_1 = self.model.model.embed_tokens(input_1)
-        outs_1 = self.model.model.layers[0].self_attn.forward(
-            hid_1, attention_mask=causal_mask_1, position_ids=position_ids_1
-        )[0]
-        # outs_1.shape == torch.Size([1, 6, 768])
-
-        outs_0_last_tokens = outs_0[:, -1, :]  # last tokens in each batch line
-        outs_1_last_tokens = outs_1[0, -3:, :]  # last three tokens
-        torch.testing.assert_close(outs_0_last_tokens, outs_1_last_tokens)
-
-    def test_causal_model_logits(self):
-        """comparing logits outputs of whole inner model"""
-        # Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
-        input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()
-
-        logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
-        logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits
-
-        logits_0_last_tokens = logits_0[:, -1, :]  # last tokens in each batch line
-        logits_1_last_tokens = logits_1[0, -3:, :]  # last three tokens
-        torch.testing.assert_close(logits_0_last_tokens, logits_1_last_tokens)
-
-
-@require_torch_gpu
-class Mask4DTestFP16(Mask4DTestBase):
-    test_attention = Mask4DTestFP32.test_attention
-
-    def setUp(self):
-        model_name = "JackFram/llama-68m"  # small Llama-like model from FlexFlow
-        self.model_dtype = torch.float16
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
-
-    def test_causal_model_logits(self):
-        """comparing logits outputs of whole inner model"""
-        # Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
-        input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()
-
-        logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
-        logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits
-
-        logits_0_last_tokens = logits_0[:, -1, :]  # last tokens in each batch line
-        logits_1_last_tokens = logits_1[0, -3:, :]  # last three tokens
-
-        indices_0 = logits_0_last_tokens.sort(descending=True).indices
-        indices_1 = logits_1_last_tokens.sort(descending=True).indices
-
-        # checking logits, but note relaxed tolerances for FP16
-        torch.testing.assert_close(logits_0_last_tokens, logits_1_last_tokens, atol=0.02, rtol=0.001)
-
-        # checking tokens order for the top tokens
-        for token_ids_0, token_ids_1 in zip(indices_0, indices_1):
-            self.assertTrue(torch.equal(token_ids_0[:128], token_ids_1[:128]))
-
-
-@slow
-@require_torch_gpu
-class Mask4DTestHard(unittest.TestCase):
-    def tearDown(self):
-        gc.collect()
-        torch.cuda.empty_cache()
-
-    def setUp(self):
-        model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-        self.model_dtype = torch.float32
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
-
-    def get_test_data(self):
-        template = "my favorite {}"
-        items = ("pet is a", "artist plays a", "name is L")  # same number of tokens in each item
-
-        batch_0 = [template.format(x) for x in items]  # 3 separate lines
-        batch_1 = template.format(" ".join(items))  # 1 line with options concatenated
-
-        input_0 = self.tokenizer(batch_0, return_tensors="pt").input_ids.to(torch_device)
-        input_1 = self.tokenizer(batch_1, return_tensors="pt").input_ids.to(torch_device)
-
-        mask_1 = torch.tensor(
-            [
-                [
-                    [
-                        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-                        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-                        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
-                        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
-                        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
-                        [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
-                        [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
-                        [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0],
-                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
-                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0],
-                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1],
-                    ]
-                ]
-            ],
-            device=torch_device,
-            dtype=torch.int64,
-        )
-
-        position_ids_0 = torch.arange(input_0.shape[1]).tile(input_0.shape[0], 1).to(torch_device)
-        # equivalent: position_ids_1 = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device)
-        position_ids_1 = (mask_1.sum(dim=-1) - 1).reshape(1, -1)  # same but nicer
-
-        return input_0, position_ids_0, input_1, mask_1, position_ids_1
-
-    def test_stacked_causal_mask(self):
-        # Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
-        input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()
-
-        # regular batch
-        logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
-        logits_0_last = logits_0[:, -1, :]  # last tokens in each batch line
-        decoded_0 = [self.tokenizer.decode(t) for t in logits_0_last.argmax(dim=-1)]
-
-        # single forward run with 4D custom mask
-        logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits
-        logits_1_last = logits_1[0, torch.where(position_ids_1 == position_ids_1.max())[1], :]  # last three tokens
-        decoded_1 = [self.tokenizer.decode(t) for t in logits_1_last.argmax(dim=-1)]
-
-        self.assertEqual(decoded_0, decoded_1)
-
-    def test_partial_stacked_causal_mask(self):
-        # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention
-        # masks
-
-        # Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
-        input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()
-
-        # regular batch
-        logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
-        logits_0_last = logits_0[:, -1, :]  # last tokens in each batch line
-        decoded_0 = [self.tokenizer.decode(t) for t in logits_0_last.argmax(dim=-1)]
-
-        # 2 forward runs with custom 4D masks
-        part_a = 3  # split point
-
-        input_1a = input_1[:, :part_a]
-        position_ids_1a = position_ids_1[:, :part_a]
-        mask_1a = mask_1[:, :, :part_a, :part_a]
-
-        outs_1a = self.model.forward(input_1a, attention_mask=mask_1a.bool(), position_ids=position_ids_1a)
-        past_key_values_a = outs_1a["past_key_values"]
-
-        input_1b = input_1[:, part_a:]
-        position_ids_1b = position_ids_1[:, part_a:]
-        mask_1b = mask_1[:, :, part_a:, :]
-
-        outs_1b = self.model.forward(
-            input_1b, attention_mask=mask_1b.bool(), position_ids=position_ids_1b, past_key_values=past_key_values_a
-        )
-
-        decoded_1b = [
-            self.tokenizer.decode(t)
-            for t in outs_1b.logits.argmax(-1)[0, torch.where(position_ids_1 == position_ids_1.max())[1] - part_a]
-        ]
-
-        self.assertEqual(decoded_0, decoded_1b)
-
-
 @require_torch
 class TestTensorSharing(TestCasePlus):
     def test_disjoint(self):