Llama: fix custom 4D masks, v2 (#30348)

* 4d mask fixes

* Update custom 4D mask logic

* test moved to mixin

* extra tests 4d mask

* upd 4d mask and StaticCache handling

* added Mask4DTestHard to mistral tests

* post-rebase fixes

* test fixes for StaticCache

* make fix-copies

* upd 1 after #30476

* fix common tests

* rm elif attention_mask.dim() == 4:

* tests combined, fixed, mixtral supported

* bigbird style chg reverted

* rm if attention_mask.dim() == 2

* modeling_llama formatting chg

---------

Co-authored-by: Joao Gante <joao@huggingface.co>
This commit is contained in:
Poedator
2024-05-13 13:46:06 +02:00
committed by GitHub
parent 453893ed15
commit a0779b9e19
11 changed files with 541 additions and 366 deletions

View File

@@ -12,8 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch LLaMA model. """
"""Testing suite for the PyTorch LLaMA model."""
import gc
import tempfile
import unittest
@@ -21,7 +22,7 @@ import pytest
from packaging import version
from parameterized import parameterized
from transformers import LlamaConfig, is_torch_available, set_seed
from transformers import LlamaConfig, StaticCache, is_torch_available, set_seed
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
@@ -804,7 +805,7 @@ end
'<PRE> <SUF>\ndef main():\n factory = InterfaceManagerFactory(start=datetime.now())\n managers = []\n for i in range(10):\n managers.append(factory.build(id=i))\n <MID> class InterfaceManagerFactory(AbstractManagerFactory):\n def __init__(',
'<PRE> <SUF> = 0 :=\nbegin\nsplit,\n{ intros h f,\n rw pi_1_etalisation at h,\n simp [h],\n refl\n},\n{ intro h,\n have := @quasi_adjoint C D P,\n simp [←pi_1_etalisation, this, h],\n refl\n}\nend\n <MID> /-- A quasi-prefunctoid is 1-connected iff all its etalisations are 1-connected. -/\ntheorem connected_iff_etalisation [C D : precategoroid] (P : quasi_prefunctoid C D) :\nπ₁ P = 0 ↔ '
]
EXPECTED_IDS = torch.tensor([[ 1, 32007, 822, 3349, 29918, 5464, 29918, 294, 18869, 29898,29879, 29901, 851, 29897, 1599, 851, 29901, 13, 1678, 9995, 29871, 32008, 13, 1678, 736, 1121, 13, 32009, 15941, 1661, 29899, 28599, 2687, 4890, 515, 263, 1347, 29889, 13, 13, 1678, 826, 3174, 29901, 13, 4706, 269, 29901, 450, 1347, 304, 3349, 1661, 29899, 28599, 2687, 4890, 515, 29889, 13, 13, 1678, 16969, 29901, 13, 4706, 450, 1347, 411, 1661, 29899, 28599, 2687, 4890, 6206, 29889, 13, 1678, 9995, 13, 1678, 1121, 353, 5124, 13, 1678, 363, 274, 297, 269, 29901, 13, 4706, 565, 4356, 29898, 29883, 29897, 529, 29871, 29896, 29906, 29947, 29901, 13, 9651, 1121, 4619, 274, 32010, 2]])
EXPECTED_IDS = torch.tensor([[1, 32007, 822, 3349, 29918, 5464, 29918, 294, 18869, 29898, 29879, 29901, 851, 29897, 1599, 851, 29901, 13, 1678, 9995, 29871, 32008, 13, 1678, 736, 1121, 13, 32009, 15941, 1661, 29899, 28599, 2687, 4890, 515, 263, 1347, 29889, 13, 13, 1678, 826, 3174, 29901, 13, 4706, 269, 29901, 450, 1347, 304, 3349, 1661, 29899, 28599, 2687, 4890, 515, 29889, 13, 13, 1678, 16969, 29901, 13, 4706, 450, 1347, 411, 1661, 29899, 28599, 2687, 4890, 6206, 29889, 13, 1678, 9995, 13, 1678, 1121, 353, 5124, 13, 1678, 363, 274, 297, 269, 29901, 13, 4706, 565, 4356, 29898, 29883, 29897, 529, 29871, 29896, 29906, 29947, 29901, 13, 9651, 1121, 4619, 274, 32010, 2]])
# fmt: on
self.assertEqual(processed_text_suffix_first, EXPECTED_TEXT)
input_ids = tokenizer(self.PROMPTS[0], return_tensors="pt")["input_ids"]
@@ -816,3 +817,253 @@ end
]
infilling = tokenizer.batch_decode(generated_ids)
self.assertEqual(infilling, EXPECTED_INFILLING)
@slow
@require_torch_gpu
class Mask4DTestHard(unittest.TestCase):
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
def setUp(self):
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
self.model_dtype = torch.float32
self.tokenizer = LlamaTokenizer.from_pretrained(model_name)
self.model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
def get_test_data(self):
template = "my favorite {}"
items = ("pet is a", "artist plays a", "name is L") # same number of tokens in each item
batch_separate = [template.format(x) for x in items] # 3 separate lines
batch_shared_prefix = template.format(" ".join(items)) # 1 line with options concatenated
input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device)
input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device)
mask_shared_prefix = torch.tensor(
[
[
[
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1],
]
]
],
device=torch_device,
)
position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device)
# building custom positions ids based on custom mask
position_ids_shared_prefix = (mask_shared_prefix.sum(dim=-1) - 1).reshape(1, -1)
# effectively: position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device)
# inverting the mask
min_dtype = torch.finfo(self.model_dtype).min
mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=self.model_dtype) * min_dtype
return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
def test_stacked_causal_mask(self):
(
input_ids,
position_ids,
input_ids_shared_prefix,
mask_shared_prefix,
position_ids_shared_prefix,
) = self.get_test_data()
# regular batch
logits = self.model.forward(input_ids, position_ids=position_ids).logits
logits_last = logits[:, -1, :] # last tokens in each batch line
decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
# single forward run with 4D custom mask
logits_shared_prefix = self.model.forward(
input_ids_shared_prefix, attention_mask=mask_shared_prefix, position_ids=position_ids_shared_prefix
).logits
logits_shared_prefix_last = logits_shared_prefix[
0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], :
] # last three tokens
decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)]
self.assertEqual(decoded, decoded_shared_prefix)
def test_partial_stacked_causal_mask(self):
# Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks
(
input_ids,
position_ids,
input_ids_shared_prefix,
mask_shared_prefix,
position_ids_shared_prefix,
) = self.get_test_data()
# regular batch
logits = self.model.forward(input_ids, position_ids=position_ids).logits
logits_last = logits[:, -1, :] # last tokens in each batch line
decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
# 2 forward runs with custom 4D masks
part_a = 3 # split point
input_1a = input_ids_shared_prefix[:, :part_a]
position_ids_1a = position_ids_shared_prefix[:, :part_a]
mask_1a = mask_shared_prefix[:, :, :part_a, :part_a]
outs_1a = self.model.forward(input_1a, attention_mask=mask_1a, position_ids=position_ids_1a)
past_key_values_a = outs_1a["past_key_values"]
# Case 1: we pass a 4D attention mask regarding the current sequence length (i.e. [..., seq_len, full_len])
input_1b = input_ids_shared_prefix[:, part_a:]
position_ids_1b = position_ids_shared_prefix[:, part_a:]
mask_1b = mask_shared_prefix[:, :, part_a:, :]
outs_1b = self.model.forward(
input_1b,
attention_mask=mask_1b,
position_ids=position_ids_1b,
past_key_values=past_key_values_a,
)
decoded_1b = [
self.tokenizer.decode(t)
for t in outs_1b.logits.argmax(-1)[
0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a
]
]
self.assertEqual(decoded, decoded_1b)
def test_stacked_causal_mask_static_cache(self):
"""same as above but with StaticCache"""
(
input_ids,
position_ids,
input_ids_shared_prefix,
mask_shared_prefix,
position_ids_shared_prefix,
) = self.get_test_data()
# regular batch
logits = self.model.forward(input_ids, position_ids=position_ids).logits
logits_last = logits[:, -1, :] # last tokens in each batch line
decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
# upgrade the model with StaticCache
max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1]
past_key_values = StaticCache(
config=self.model.config,
max_batch_size=1,
max_cache_len=max_cache_len,
device=torch_device,
dtype=self.model.dtype,
)
padded_attention_mask = torch.nn.functional.pad(
input=mask_shared_prefix,
pad=(0, max_cache_len - mask_shared_prefix.shape[-1]),
mode="constant",
value=torch.finfo(self.model_dtype).min,
)
# single forward run with 4D custom mask
logits_shared_prefix = self.model.forward(
input_ids_shared_prefix,
attention_mask=padded_attention_mask,
position_ids=position_ids_shared_prefix,
cache_position=torch.arange(input_ids_shared_prefix.shape[-1], device=torch_device),
past_key_values=past_key_values,
).logits
logits_shared_prefix_last = logits_shared_prefix[
0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], :
] # last three tokens
decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)]
self.assertEqual(decoded, decoded_shared_prefix)
def test_partial_stacked_causal_mask_static_cache(self):
# Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks
# we pass a 4D attention mask shaped [..., seq_len, full_static_cache_len])
(
input_ids,
position_ids,
input_ids_shared_prefix,
mask_shared_prefix,
position_ids_shared_prefix,
) = self.get_test_data()
# regular batch
logits = self.model.forward(input_ids, position_ids=position_ids).logits
logits_last = logits[:, -1, :] # last tokens in each batch line
decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
# upgrade the model with StaticCache
max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1]
past_key_values = StaticCache(
config=self.model.config,
max_batch_size=1,
max_cache_len=max_cache_len,
device=torch_device,
dtype=self.model.dtype,
)
# forward run for the first part of input
part_a = 3 # split point
input_1a = input_ids_shared_prefix[:, :part_a]
position_ids_1a = position_ids_shared_prefix[:, :part_a]
mask_1a = mask_shared_prefix[:, :, :part_a, :part_a]
padded_mask_1a = torch.nn.functional.pad(
input=mask_1a,
pad=(0, max_cache_len - mask_1a.shape[-1]),
mode="constant",
value=torch.finfo(self.model_dtype).min,
)
_ = self.model.forward(
input_1a,
attention_mask=padded_mask_1a,
position_ids=position_ids_1a,
cache_position=torch.arange(part_a, device=torch_device),
past_key_values=past_key_values,
)
# forward run for the second part of input
input_1b = input_ids_shared_prefix[:, part_a:]
position_ids_1b = position_ids_shared_prefix[:, part_a:]
mask_1b = mask_shared_prefix[:, :, part_a:, :]
padded_mask_1b = torch.nn.functional.pad(
input=mask_1b, pad=(0, max_cache_len - mask_1b.shape[-1]), mode="constant", value=0
)
outs_1b = self.model.forward(
input_1b,
attention_mask=padded_mask_1b,
position_ids=position_ids_1b,
cache_position=torch.arange(
part_a,
input_ids_shared_prefix.shape[-1],
device=torch_device,
),
past_key_values=past_key_values,
)
decoded_1b = [
self.tokenizer.decode(t)
for t in outs_1b.logits.argmax(-1)[
0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a
]
]
self.assertEqual(decoded, decoded_1b)