Llama: fix custom 4D masks, v2 (#30348)
* 4d mask fixes * Update custom 4D mask logic * test moved to mixin * extra tests 4d mask * upd 4d mask and StaticCache handling * added Mask4DTestHard to mistral tests * post-rebase fixes * test fixes for StaticCache * make fix-copies * upd 1 after #30476 * fix common tests * rm elif attention_mask.dim() == 4: * tests combined, fixed, mixtral supported * bigbird style chg reverted * rm if attention_mask.dim() == 2 * modeling_llama formatting chg --------- Co-authored-by: Joao Gante <joao@huggingface.co>
This commit is contained in:
@@ -12,8 +12,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch LLaMA model. """
|
||||
"""Testing suite for the PyTorch LLaMA model."""
|
||||
|
||||
import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
@@ -21,7 +22,7 @@ import pytest
|
||||
from packaging import version
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import LlamaConfig, is_torch_available, set_seed
|
||||
from transformers import LlamaConfig, StaticCache, is_torch_available, set_seed
|
||||
from transformers.testing_utils import (
|
||||
require_bitsandbytes,
|
||||
require_flash_attn,
|
||||
@@ -804,7 +805,7 @@ end
|
||||
'<PRE> <SUF>\ndef main():\n factory = InterfaceManagerFactory(start=datetime.now())\n managers = []\n for i in range(10):\n managers.append(factory.build(id=i))\n <MID> class InterfaceManagerFactory(AbstractManagerFactory):\n def __init__(',
|
||||
'<PRE> <SUF> = 0 :=\nbegin\nsplit,\n{ intros h f,\n rw pi_1_etalisation at h,\n simp [h],\n refl\n},\n{ intro h,\n have := @quasi_adjoint C D P,\n simp [←pi_1_etalisation, this, h],\n refl\n}\nend\n <MID> /-- A quasi-prefunctoid is 1-connected iff all its etalisations are 1-connected. -/\ntheorem connected_iff_etalisation [C D : precategoroid] (P : quasi_prefunctoid C D) :\nπ₁ P = 0 ↔ '
|
||||
]
|
||||
EXPECTED_IDS = torch.tensor([[ 1, 32007, 822, 3349, 29918, 5464, 29918, 294, 18869, 29898,29879, 29901, 851, 29897, 1599, 851, 29901, 13, 1678, 9995, 29871, 32008, 13, 1678, 736, 1121, 13, 32009, 15941, 1661, 29899, 28599, 2687, 4890, 515, 263, 1347, 29889, 13, 13, 1678, 826, 3174, 29901, 13, 4706, 269, 29901, 450, 1347, 304, 3349, 1661, 29899, 28599, 2687, 4890, 515, 29889, 13, 13, 1678, 16969, 29901, 13, 4706, 450, 1347, 411, 1661, 29899, 28599, 2687, 4890, 6206, 29889, 13, 1678, 9995, 13, 1678, 1121, 353, 5124, 13, 1678, 363, 274, 297, 269, 29901, 13, 4706, 565, 4356, 29898, 29883, 29897, 529, 29871, 29896, 29906, 29947, 29901, 13, 9651, 1121, 4619, 274, 32010, 2]])
|
||||
EXPECTED_IDS = torch.tensor([[1, 32007, 822, 3349, 29918, 5464, 29918, 294, 18869, 29898, 29879, 29901, 851, 29897, 1599, 851, 29901, 13, 1678, 9995, 29871, 32008, 13, 1678, 736, 1121, 13, 32009, 15941, 1661, 29899, 28599, 2687, 4890, 515, 263, 1347, 29889, 13, 13, 1678, 826, 3174, 29901, 13, 4706, 269, 29901, 450, 1347, 304, 3349, 1661, 29899, 28599, 2687, 4890, 515, 29889, 13, 13, 1678, 16969, 29901, 13, 4706, 450, 1347, 411, 1661, 29899, 28599, 2687, 4890, 6206, 29889, 13, 1678, 9995, 13, 1678, 1121, 353, 5124, 13, 1678, 363, 274, 297, 269, 29901, 13, 4706, 565, 4356, 29898, 29883, 29897, 529, 29871, 29896, 29906, 29947, 29901, 13, 9651, 1121, 4619, 274, 32010, 2]])
|
||||
# fmt: on
|
||||
self.assertEqual(processed_text_suffix_first, EXPECTED_TEXT)
|
||||
input_ids = tokenizer(self.PROMPTS[0], return_tensors="pt")["input_ids"]
|
||||
@@ -816,3 +817,253 @@ end
|
||||
]
|
||||
infilling = tokenizer.batch_decode(generated_ids)
|
||||
self.assertEqual(infilling, EXPECTED_INFILLING)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class Mask4DTestHard(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def setUp(self):
|
||||
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
||||
self.model_dtype = torch.float32
|
||||
self.tokenizer = LlamaTokenizer.from_pretrained(model_name)
|
||||
self.model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
|
||||
|
||||
def get_test_data(self):
|
||||
template = "my favorite {}"
|
||||
items = ("pet is a", "artist plays a", "name is L") # same number of tokens in each item
|
||||
|
||||
batch_separate = [template.format(x) for x in items] # 3 separate lines
|
||||
batch_shared_prefix = template.format(" ".join(items)) # 1 line with options concatenated
|
||||
|
||||
input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device)
|
||||
input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
mask_shared_prefix = torch.tensor(
|
||||
[
|
||||
[
|
||||
[
|
||||
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0],
|
||||
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1],
|
||||
]
|
||||
]
|
||||
],
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device)
|
||||
|
||||
# building custom positions ids based on custom mask
|
||||
position_ids_shared_prefix = (mask_shared_prefix.sum(dim=-1) - 1).reshape(1, -1)
|
||||
# effectively: position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device)
|
||||
|
||||
# inverting the mask
|
||||
min_dtype = torch.finfo(self.model_dtype).min
|
||||
mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=self.model_dtype) * min_dtype
|
||||
|
||||
return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
|
||||
|
||||
def test_stacked_causal_mask(self):
|
||||
(
|
||||
input_ids,
|
||||
position_ids,
|
||||
input_ids_shared_prefix,
|
||||
mask_shared_prefix,
|
||||
position_ids_shared_prefix,
|
||||
) = self.get_test_data()
|
||||
|
||||
# regular batch
|
||||
logits = self.model.forward(input_ids, position_ids=position_ids).logits
|
||||
logits_last = logits[:, -1, :] # last tokens in each batch line
|
||||
decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
|
||||
|
||||
# single forward run with 4D custom mask
|
||||
logits_shared_prefix = self.model.forward(
|
||||
input_ids_shared_prefix, attention_mask=mask_shared_prefix, position_ids=position_ids_shared_prefix
|
||||
).logits
|
||||
logits_shared_prefix_last = logits_shared_prefix[
|
||||
0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], :
|
||||
] # last three tokens
|
||||
decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)]
|
||||
|
||||
self.assertEqual(decoded, decoded_shared_prefix)
|
||||
|
||||
def test_partial_stacked_causal_mask(self):
|
||||
# Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks
|
||||
|
||||
(
|
||||
input_ids,
|
||||
position_ids,
|
||||
input_ids_shared_prefix,
|
||||
mask_shared_prefix,
|
||||
position_ids_shared_prefix,
|
||||
) = self.get_test_data()
|
||||
|
||||
# regular batch
|
||||
logits = self.model.forward(input_ids, position_ids=position_ids).logits
|
||||
logits_last = logits[:, -1, :] # last tokens in each batch line
|
||||
decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
|
||||
|
||||
# 2 forward runs with custom 4D masks
|
||||
part_a = 3 # split point
|
||||
|
||||
input_1a = input_ids_shared_prefix[:, :part_a]
|
||||
position_ids_1a = position_ids_shared_prefix[:, :part_a]
|
||||
mask_1a = mask_shared_prefix[:, :, :part_a, :part_a]
|
||||
|
||||
outs_1a = self.model.forward(input_1a, attention_mask=mask_1a, position_ids=position_ids_1a)
|
||||
past_key_values_a = outs_1a["past_key_values"]
|
||||
|
||||
# Case 1: we pass a 4D attention mask regarding the current sequence length (i.e. [..., seq_len, full_len])
|
||||
input_1b = input_ids_shared_prefix[:, part_a:]
|
||||
position_ids_1b = position_ids_shared_prefix[:, part_a:]
|
||||
mask_1b = mask_shared_prefix[:, :, part_a:, :]
|
||||
outs_1b = self.model.forward(
|
||||
input_1b,
|
||||
attention_mask=mask_1b,
|
||||
position_ids=position_ids_1b,
|
||||
past_key_values=past_key_values_a,
|
||||
)
|
||||
decoded_1b = [
|
||||
self.tokenizer.decode(t)
|
||||
for t in outs_1b.logits.argmax(-1)[
|
||||
0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a
|
||||
]
|
||||
]
|
||||
self.assertEqual(decoded, decoded_1b)
|
||||
|
||||
def test_stacked_causal_mask_static_cache(self):
|
||||
"""same as above but with StaticCache"""
|
||||
(
|
||||
input_ids,
|
||||
position_ids,
|
||||
input_ids_shared_prefix,
|
||||
mask_shared_prefix,
|
||||
position_ids_shared_prefix,
|
||||
) = self.get_test_data()
|
||||
|
||||
# regular batch
|
||||
logits = self.model.forward(input_ids, position_ids=position_ids).logits
|
||||
logits_last = logits[:, -1, :] # last tokens in each batch line
|
||||
decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
|
||||
|
||||
# upgrade the model with StaticCache
|
||||
max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1]
|
||||
past_key_values = StaticCache(
|
||||
config=self.model.config,
|
||||
max_batch_size=1,
|
||||
max_cache_len=max_cache_len,
|
||||
device=torch_device,
|
||||
dtype=self.model.dtype,
|
||||
)
|
||||
|
||||
padded_attention_mask = torch.nn.functional.pad(
|
||||
input=mask_shared_prefix,
|
||||
pad=(0, max_cache_len - mask_shared_prefix.shape[-1]),
|
||||
mode="constant",
|
||||
value=torch.finfo(self.model_dtype).min,
|
||||
)
|
||||
|
||||
# single forward run with 4D custom mask
|
||||
logits_shared_prefix = self.model.forward(
|
||||
input_ids_shared_prefix,
|
||||
attention_mask=padded_attention_mask,
|
||||
position_ids=position_ids_shared_prefix,
|
||||
cache_position=torch.arange(input_ids_shared_prefix.shape[-1], device=torch_device),
|
||||
past_key_values=past_key_values,
|
||||
).logits
|
||||
logits_shared_prefix_last = logits_shared_prefix[
|
||||
0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], :
|
||||
] # last three tokens
|
||||
decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)]
|
||||
|
||||
self.assertEqual(decoded, decoded_shared_prefix)
|
||||
|
||||
def test_partial_stacked_causal_mask_static_cache(self):
|
||||
# Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks
|
||||
# we pass a 4D attention mask shaped [..., seq_len, full_static_cache_len])
|
||||
(
|
||||
input_ids,
|
||||
position_ids,
|
||||
input_ids_shared_prefix,
|
||||
mask_shared_prefix,
|
||||
position_ids_shared_prefix,
|
||||
) = self.get_test_data()
|
||||
|
||||
# regular batch
|
||||
logits = self.model.forward(input_ids, position_ids=position_ids).logits
|
||||
logits_last = logits[:, -1, :] # last tokens in each batch line
|
||||
decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
|
||||
|
||||
# upgrade the model with StaticCache
|
||||
max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1]
|
||||
past_key_values = StaticCache(
|
||||
config=self.model.config,
|
||||
max_batch_size=1,
|
||||
max_cache_len=max_cache_len,
|
||||
device=torch_device,
|
||||
dtype=self.model.dtype,
|
||||
)
|
||||
|
||||
# forward run for the first part of input
|
||||
part_a = 3 # split point
|
||||
|
||||
input_1a = input_ids_shared_prefix[:, :part_a]
|
||||
position_ids_1a = position_ids_shared_prefix[:, :part_a]
|
||||
mask_1a = mask_shared_prefix[:, :, :part_a, :part_a]
|
||||
|
||||
padded_mask_1a = torch.nn.functional.pad(
|
||||
input=mask_1a,
|
||||
pad=(0, max_cache_len - mask_1a.shape[-1]),
|
||||
mode="constant",
|
||||
value=torch.finfo(self.model_dtype).min,
|
||||
)
|
||||
|
||||
_ = self.model.forward(
|
||||
input_1a,
|
||||
attention_mask=padded_mask_1a,
|
||||
position_ids=position_ids_1a,
|
||||
cache_position=torch.arange(part_a, device=torch_device),
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
# forward run for the second part of input
|
||||
input_1b = input_ids_shared_prefix[:, part_a:]
|
||||
position_ids_1b = position_ids_shared_prefix[:, part_a:]
|
||||
mask_1b = mask_shared_prefix[:, :, part_a:, :]
|
||||
|
||||
padded_mask_1b = torch.nn.functional.pad(
|
||||
input=mask_1b, pad=(0, max_cache_len - mask_1b.shape[-1]), mode="constant", value=0
|
||||
)
|
||||
|
||||
outs_1b = self.model.forward(
|
||||
input_1b,
|
||||
attention_mask=padded_mask_1b,
|
||||
position_ids=position_ids_1b,
|
||||
cache_position=torch.arange(
|
||||
part_a,
|
||||
input_ids_shared_prefix.shape[-1],
|
||||
device=torch_device,
|
||||
),
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
decoded_1b = [
|
||||
self.tokenizer.decode(t)
|
||||
for t in outs_1b.logits.argmax(-1)[
|
||||
0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a
|
||||
]
|
||||
]
|
||||
self.assertEqual(decoded, decoded_1b)
|
||||
|
||||
@@ -627,3 +627,127 @@ class MistralIntegrationTest(unittest.TestCase):
|
||||
del model
|
||||
backend_empty_cache(torch_device)
|
||||
gc.collect()
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class Mask4DTestHard(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def setUp(self):
|
||||
model_name = "mistralai/Mistral-7B-v0.1"
|
||||
self.model_dtype = torch.float32
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
|
||||
self.model = MistralForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
|
||||
|
||||
def get_test_data(self):
|
||||
template = "my favorite {}"
|
||||
items = ("pet is a", "artist plays a", "name is L") # same number of tokens in each item
|
||||
|
||||
batch_separate = [template.format(x) for x in items] # 3 separate lines
|
||||
batch_shared_prefix = template.format(" ".join(items)) # 1 line with options concatenated
|
||||
|
||||
input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device)
|
||||
input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
mask_shared_prefix = torch.tensor(
|
||||
[
|
||||
[
|
||||
[
|
||||
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0],
|
||||
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1],
|
||||
]
|
||||
]
|
||||
],
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device)
|
||||
|
||||
# building custom positions ids based on custom mask
|
||||
position_ids_shared_prefix = (mask_shared_prefix.sum(dim=-1) - 1).reshape(1, -1)
|
||||
# effectively: position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device)
|
||||
|
||||
# inverting the mask
|
||||
min_dtype = torch.finfo(self.model_dtype).min
|
||||
mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=self.model_dtype) * min_dtype
|
||||
|
||||
return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
|
||||
|
||||
def test_stacked_causal_mask(self):
|
||||
(
|
||||
input_ids,
|
||||
position_ids,
|
||||
input_ids_shared_prefix,
|
||||
mask_shared_prefix,
|
||||
position_ids_shared_prefix,
|
||||
) = self.get_test_data()
|
||||
|
||||
# regular batch
|
||||
logits = self.model.forward(input_ids, position_ids=position_ids).logits
|
||||
logits_last = logits[:, -1, :] # last tokens in each batch line
|
||||
decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
|
||||
|
||||
# single forward run with 4D custom mask
|
||||
logits_shared_prefix = self.model.forward(
|
||||
input_ids_shared_prefix, attention_mask=mask_shared_prefix, position_ids=position_ids_shared_prefix
|
||||
).logits
|
||||
logits_shared_prefix_last = logits_shared_prefix[
|
||||
0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], :
|
||||
] # last three tokens
|
||||
decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)]
|
||||
|
||||
self.assertEqual(decoded, decoded_shared_prefix)
|
||||
|
||||
def test_partial_stacked_causal_mask(self):
|
||||
# Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks
|
||||
|
||||
(
|
||||
input_ids,
|
||||
position_ids,
|
||||
input_ids_shared_prefix,
|
||||
mask_shared_prefix,
|
||||
position_ids_shared_prefix,
|
||||
) = self.get_test_data()
|
||||
|
||||
# regular batch
|
||||
logits = self.model.forward(input_ids, position_ids=position_ids).logits
|
||||
logits_last = logits[:, -1, :] # last tokens in each batch line
|
||||
decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
|
||||
|
||||
# 2 forward runs with custom 4D masks
|
||||
part_a = 3 # split point
|
||||
|
||||
input_1a = input_ids_shared_prefix[:, :part_a]
|
||||
position_ids_1a = position_ids_shared_prefix[:, :part_a]
|
||||
mask_1a = mask_shared_prefix[:, :, :part_a, :part_a]
|
||||
|
||||
outs_1a = self.model.forward(input_1a, attention_mask=mask_1a, position_ids=position_ids_1a)
|
||||
past_key_values_a = outs_1a["past_key_values"]
|
||||
|
||||
# Case 1: we pass a 4D attention mask regarding the current sequence length (i.e. [..., seq_len, full_len])
|
||||
input_1b = input_ids_shared_prefix[:, part_a:]
|
||||
position_ids_1b = position_ids_shared_prefix[:, part_a:]
|
||||
mask_1b = mask_shared_prefix[:, :, part_a:, :]
|
||||
outs_1b = self.model.forward(
|
||||
input_1b, attention_mask=mask_1b, position_ids=position_ids_1b, past_key_values=past_key_values_a
|
||||
)
|
||||
decoded_1b = [
|
||||
self.tokenizer.decode(t)
|
||||
for t in outs_1b.logits.argmax(-1)[
|
||||
0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a
|
||||
]
|
||||
]
|
||||
self.assertEqual(decoded, decoded_1b)
|
||||
|
||||
@@ -4277,6 +4277,80 @@ class ModelTesterMixin:
|
||||
|
||||
self.assertFalse(fa2_correctly_converted)
|
||||
|
||||
def _get_custom_4d_mask_test_data(self):
|
||||
# Sequence in which all but the last token is the same
|
||||
input_ids = torch.tensor(
|
||||
[[10, 11, 12, 13], [10, 11, 12, 14], [10, 11, 12, 15]], device=torch_device, dtype=torch.int64
|
||||
)
|
||||
position_ids = torch.tensor([[0, 1, 2, 3]] * 3, device=torch_device, dtype=torch.int64)
|
||||
|
||||
# Combining common prefix with the unique ending tokens:
|
||||
input_ids_shared_prefix = torch.cat([input_ids[0][:-1], input_ids[:, -1]]).unsqueeze(0)
|
||||
|
||||
# Creating a 4D mask where each of the last 3 tokens do not attend to each other.
|
||||
mask_shared_prefix = torch.tensor(
|
||||
[
|
||||
[
|
||||
[
|
||||
[1, 0, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 1, 0, 0],
|
||||
[1, 1, 1, 0, 1, 0],
|
||||
[1, 1, 1, 0, 0, 1],
|
||||
]
|
||||
]
|
||||
],
|
||||
)
|
||||
# inverting the attention mask
|
||||
mask_dtype = torch.float32
|
||||
min_dtype = torch.finfo(mask_dtype).min
|
||||
mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=mask_dtype, device=torch_device) * min_dtype
|
||||
|
||||
# Creating a position_ids tensor. note the repeating figures in the end.
|
||||
position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 3, 3]], device=torch_device, dtype=torch.int64)
|
||||
|
||||
return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
|
||||
|
||||
def test_custom_4d_attention_mask(self):
|
||||
if len(self.all_generative_model_classes) == 0:
|
||||
self.skipTest("Model architecture has no generative classes, and thus not necessarily supporting 4D masks")
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_cache_class:
|
||||
self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks")
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config).to(device=torch_device, dtype=torch.float32)
|
||||
|
||||
(
|
||||
input_ids,
|
||||
position_ids,
|
||||
input_ids_shared_prefix,
|
||||
mask_shared_prefix,
|
||||
position_ids_shared_prefix,
|
||||
) = self._get_custom_4d_mask_test_data()
|
||||
|
||||
logits = model.forward(input_ids, position_ids=position_ids).logits
|
||||
# logits.shape == torch.Size([3, 4, ...])
|
||||
|
||||
logits_shared_prefix = model(
|
||||
input_ids_shared_prefix,
|
||||
attention_mask=mask_shared_prefix,
|
||||
position_ids=position_ids_shared_prefix,
|
||||
)[0]
|
||||
# logits_shared_prefix.shape == torch.Size([1, 6, ...])
|
||||
|
||||
out_last_tokens = logits[:, -1, :] # last tokens in each batch line
|
||||
out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens
|
||||
|
||||
# comparing greedily-chosen tokens:
|
||||
assert torch.equal(out_last_tokens.max(axis=1).indices, out_shared_prefix_last_tokens.max(axis=1).indices)
|
||||
|
||||
# comparing softmax-normalized logits:
|
||||
normalized_0 = F.softmax(out_last_tokens)
|
||||
normalized_1 = F.softmax(out_shared_prefix_last_tokens)
|
||||
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import gc
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
@@ -53,7 +52,6 @@ from transformers.testing_utils import (
|
||||
require_tf,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_accelerator,
|
||||
require_usr_bin_time,
|
||||
slow,
|
||||
@@ -2107,229 +2105,6 @@ class TestAttentionImplementation(unittest.TestCase):
|
||||
self.assertTrue("PyTorch SDPA requirements in Transformers are not met" in str(cm.exception))
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
class Mask4DTestBase(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_test_data(self):
|
||||
texts = ["the cat sat", "the cat had", "the cat is"]
|
||||
encoded = [self.tokenizer.encode(t) for t in texts]
|
||||
input_0 = torch.tensor(encoded, device=torch_device)
|
||||
# tensor([[ 1, 278, 6635, 3290],
|
||||
# [ 1, 278, 6635, 750],
|
||||
# [ 1, 278, 6635, 338]], device='cuda:0')
|
||||
|
||||
position_ids_0 = torch.tensor([[0, 1, 2, 3]] * 3, device=torch_device, dtype=torch.int64)
|
||||
|
||||
# Combining common prefix with the unique ending tokens:
|
||||
input_1 = torch.cat([input_0[0][:-1], input_0[:, -1]]).unsqueeze(0)
|
||||
# tensor([[ 1, 278, 6635, 3290, 750, 338]], device='cuda:0')
|
||||
|
||||
# Creating a 4D mask where each of the last 3 tokens do not attend to each other.
|
||||
mask_1 = torch.tensor(
|
||||
[
|
||||
[
|
||||
[
|
||||
[1, 0, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 1, 0, 0],
|
||||
[1, 1, 1, 0, 1, 0],
|
||||
[1, 1, 1, 0, 0, 1],
|
||||
]
|
||||
]
|
||||
],
|
||||
device="cuda:0",
|
||||
dtype=torch.int64,
|
||||
)
|
||||
|
||||
# Creating a position_ids tensor. note the repeating figures in the end.
|
||||
position_ids_1 = torch.tensor([[0, 1, 2, 3, 3, 3]], device=torch_device, dtype=torch.int64)
|
||||
|
||||
return input_0, position_ids_0, input_1, mask_1, position_ids_1
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
class Mask4DTestFP32(Mask4DTestBase):
|
||||
def setUp(self):
|
||||
model_name = "JackFram/llama-68m" # small Llama-like model from FlexFlow
|
||||
self.model_dtype = torch.float32
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
|
||||
|
||||
def test_attention(self):
|
||||
"""comparing outputs of attention layer"""
|
||||
# Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
|
||||
input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()
|
||||
causal_mask_1 = (1 - mask_1).to(self.model_dtype) * torch.finfo(self.model_dtype).min
|
||||
|
||||
hid_0 = self.model.model.embed_tokens(input_0)
|
||||
outs_0 = self.model.model.layers[0].self_attn.forward(hid_0, position_ids=position_ids_0)[0]
|
||||
# outs_0.shape == torch.Size([3, 4, 768])
|
||||
|
||||
hid_1 = self.model.model.embed_tokens(input_1)
|
||||
outs_1 = self.model.model.layers[0].self_attn.forward(
|
||||
hid_1, attention_mask=causal_mask_1, position_ids=position_ids_1
|
||||
)[0]
|
||||
# outs_1.shape == torch.Size([1, 6, 768])
|
||||
|
||||
outs_0_last_tokens = outs_0[:, -1, :] # last tokens in each batch line
|
||||
outs_1_last_tokens = outs_1[0, -3:, :] # last three tokens
|
||||
torch.testing.assert_close(outs_0_last_tokens, outs_1_last_tokens)
|
||||
|
||||
def test_causal_model_logits(self):
|
||||
"""comparing logits outputs of whole inner model"""
|
||||
# Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
|
||||
input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()
|
||||
|
||||
logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
|
||||
logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits
|
||||
|
||||
logits_0_last_tokens = logits_0[:, -1, :] # last tokens in each batch line
|
||||
logits_1_last_tokens = logits_1[0, -3:, :] # last three tokens
|
||||
torch.testing.assert_close(logits_0_last_tokens, logits_1_last_tokens)
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
class Mask4DTestFP16(Mask4DTestBase):
|
||||
test_attention = Mask4DTestFP32.test_attention
|
||||
|
||||
def setUp(self):
|
||||
model_name = "JackFram/llama-68m" # small Llama-like model from FlexFlow
|
||||
self.model_dtype = torch.float16
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
|
||||
|
||||
def test_causal_model_logits(self):
|
||||
"""comparing logits outputs of whole inner model"""
|
||||
# Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
|
||||
input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()
|
||||
|
||||
logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
|
||||
logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits
|
||||
|
||||
logits_0_last_tokens = logits_0[:, -1, :] # last tokens in each batch line
|
||||
logits_1_last_tokens = logits_1[0, -3:, :] # last three tokens
|
||||
|
||||
indices_0 = logits_0_last_tokens.sort(descending=True).indices
|
||||
indices_1 = logits_1_last_tokens.sort(descending=True).indices
|
||||
|
||||
# checking logits, but note relaxed tolerances for FP16
|
||||
torch.testing.assert_close(logits_0_last_tokens, logits_1_last_tokens, atol=0.02, rtol=0.001)
|
||||
|
||||
# checking tokens order for the top tokens
|
||||
for token_ids_0, token_ids_1 in zip(indices_0, indices_1):
|
||||
self.assertTrue(torch.equal(token_ids_0[:128], token_ids_1[:128]))
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class Mask4DTestHard(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def setUp(self):
|
||||
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
||||
self.model_dtype = torch.float32
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
|
||||
|
||||
def get_test_data(self):
|
||||
template = "my favorite {}"
|
||||
items = ("pet is a", "artist plays a", "name is L") # same number of tokens in each item
|
||||
|
||||
batch_0 = [template.format(x) for x in items] # 3 separate lines
|
||||
batch_1 = template.format(" ".join(items)) # 1 line with options concatenated
|
||||
|
||||
input_0 = self.tokenizer(batch_0, return_tensors="pt").input_ids.to(torch_device)
|
||||
input_1 = self.tokenizer(batch_1, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
mask_1 = torch.tensor(
|
||||
[
|
||||
[
|
||||
[
|
||||
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0],
|
||||
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1],
|
||||
]
|
||||
]
|
||||
],
|
||||
device=torch_device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
|
||||
position_ids_0 = torch.arange(input_0.shape[1]).tile(input_0.shape[0], 1).to(torch_device)
|
||||
# equivalent: position_ids_1 = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device)
|
||||
position_ids_1 = (mask_1.sum(dim=-1) - 1).reshape(1, -1) # same but nicer
|
||||
|
||||
return input_0, position_ids_0, input_1, mask_1, position_ids_1
|
||||
|
||||
def test_stacked_causal_mask(self):
|
||||
# Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
|
||||
input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()
|
||||
|
||||
# regular batch
|
||||
logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
|
||||
logits_0_last = logits_0[:, -1, :] # last tokens in each batch line
|
||||
decoded_0 = [self.tokenizer.decode(t) for t in logits_0_last.argmax(dim=-1)]
|
||||
|
||||
# single forward run with 4D custom mask
|
||||
logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits
|
||||
logits_1_last = logits_1[0, torch.where(position_ids_1 == position_ids_1.max())[1], :] # last three tokens
|
||||
decoded_1 = [self.tokenizer.decode(t) for t in logits_1_last.argmax(dim=-1)]
|
||||
|
||||
self.assertEqual(decoded_0, decoded_1)
|
||||
|
||||
def test_partial_stacked_causal_mask(self):
|
||||
# Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention
|
||||
# masks
|
||||
|
||||
# Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
|
||||
input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()
|
||||
|
||||
# regular batch
|
||||
logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
|
||||
logits_0_last = logits_0[:, -1, :] # last tokens in each batch line
|
||||
decoded_0 = [self.tokenizer.decode(t) for t in logits_0_last.argmax(dim=-1)]
|
||||
|
||||
# 2 forward runs with custom 4D masks
|
||||
part_a = 3 # split point
|
||||
|
||||
input_1a = input_1[:, :part_a]
|
||||
position_ids_1a = position_ids_1[:, :part_a]
|
||||
mask_1a = mask_1[:, :, :part_a, :part_a]
|
||||
|
||||
outs_1a = self.model.forward(input_1a, attention_mask=mask_1a.bool(), position_ids=position_ids_1a)
|
||||
past_key_values_a = outs_1a["past_key_values"]
|
||||
|
||||
input_1b = input_1[:, part_a:]
|
||||
position_ids_1b = position_ids_1[:, part_a:]
|
||||
mask_1b = mask_1[:, :, part_a:, :]
|
||||
|
||||
outs_1b = self.model.forward(
|
||||
input_1b, attention_mask=mask_1b.bool(), position_ids=position_ids_1b, past_key_values=past_key_values_a
|
||||
)
|
||||
|
||||
decoded_1b = [
|
||||
self.tokenizer.decode(t)
|
||||
for t in outs_1b.logits.argmax(-1)[0, torch.where(position_ids_1 == position_ids_1.max())[1] - part_a]
|
||||
]
|
||||
|
||||
self.assertEqual(decoded_0, decoded_1b)
|
||||
|
||||
|
||||
@require_torch
|
||||
class TestTensorSharing(TestCasePlus):
|
||||
def test_disjoint(self):
|
||||
|
||||
Reference in New Issue
Block a user