Llama: fix custom 4D masks, v2 (#30348)

* 4d mask fixes

* Update custom 4D mask logic

* test moved to mixin

* extra tests 4d mask

* upd 4d mask and StaticCache handling

* added Mask4DTestHard to mistral tests

* post-rebase fixes

* test fixes for StaticCache

* make fix-copies

* upd 1 after #30476

* fix common tests

* rm elif attention_mask.dim() == 4:

* tests combined, fixed, mixtral supported

* bigbird style chg reverted

* rm if attention_mask.dim() == 2

* modeling_llama formatting chg

---------

Co-authored-by: Joao Gante <joao@huggingface.co>
This commit is contained in:
Poedator
2024-05-13 13:46:06 +02:00
committed by GitHub
parent 453893ed15
commit a0779b9e19
11 changed files with 541 additions and 366 deletions

View File

@@ -4277,6 +4277,80 @@ class ModelTesterMixin:
self.assertFalse(fa2_correctly_converted)
def _get_custom_4d_mask_test_data(self):
# Sequence in which all but the last token is the same
input_ids = torch.tensor(
[[10, 11, 12, 13], [10, 11, 12, 14], [10, 11, 12, 15]], device=torch_device, dtype=torch.int64
)
position_ids = torch.tensor([[0, 1, 2, 3]] * 3, device=torch_device, dtype=torch.int64)
# Combining common prefix with the unique ending tokens:
input_ids_shared_prefix = torch.cat([input_ids[0][:-1], input_ids[:, -1]]).unsqueeze(0)
# Creating a 4D mask where each of the last 3 tokens do not attend to each other.
mask_shared_prefix = torch.tensor(
[
[
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 0, 0],
[1, 1, 1, 0, 1, 0],
[1, 1, 1, 0, 0, 1],
]
]
],
)
# inverting the attention mask
mask_dtype = torch.float32
min_dtype = torch.finfo(mask_dtype).min
mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=mask_dtype, device=torch_device) * min_dtype
# Creating a position_ids tensor. note the repeating figures in the end.
position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 3, 3]], device=torch_device, dtype=torch.int64)
return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
def test_custom_4d_attention_mask(self):
if len(self.all_generative_model_classes) == 0:
self.skipTest("Model architecture has no generative classes, and thus not necessarily supporting 4D masks")
for model_class in self.all_generative_model_classes:
if not model_class._supports_cache_class:
self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks")
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config).to(device=torch_device, dtype=torch.float32)
(
input_ids,
position_ids,
input_ids_shared_prefix,
mask_shared_prefix,
position_ids_shared_prefix,
) = self._get_custom_4d_mask_test_data()
logits = model.forward(input_ids, position_ids=position_ids).logits
# logits.shape == torch.Size([3, 4, ...])
logits_shared_prefix = model(
input_ids_shared_prefix,
attention_mask=mask_shared_prefix,
position_ids=position_ids_shared_prefix,
)[0]
# logits_shared_prefix.shape == torch.Size([1, 6, ...])
out_last_tokens = logits[:, -1, :] # last tokens in each batch line
out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens
# comparing greedily-chosen tokens:
assert torch.equal(out_last_tokens.max(axis=1).indices, out_shared_prefix_last_tokens.max(axis=1).indices)
# comparing softmax-normalized logits:
normalized_0 = F.softmax(out_last_tokens)
normalized_1 = F.softmax(out_shared_prefix_last_tokens)
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
global_rng = random.Random()