Llama: fix custom 4D masks, v2 (#30348)
* 4d mask fixes * Update custom 4D mask logic * test moved to mixin * extra tests 4d mask * upd 4d mask and StaticCache handling * added Mask4DTestHard to mistral tests * post-rebase fixes * test fixes for StaticCache * make fix-copies * upd 1 after #30476 * fix common tests * rm elif attention_mask.dim() == 4: * tests combined, fixed, mixtral supported * bigbird style chg reverted * rm if attention_mask.dim() == 2 * modeling_llama formatting chg --------- Co-authored-by: Joao Gante <joao@huggingface.co>
This commit is contained in:
@@ -4277,6 +4277,80 @@ class ModelTesterMixin:
|
||||
|
||||
self.assertFalse(fa2_correctly_converted)
|
||||
|
||||
def _get_custom_4d_mask_test_data(self):
|
||||
# Sequence in which all but the last token is the same
|
||||
input_ids = torch.tensor(
|
||||
[[10, 11, 12, 13], [10, 11, 12, 14], [10, 11, 12, 15]], device=torch_device, dtype=torch.int64
|
||||
)
|
||||
position_ids = torch.tensor([[0, 1, 2, 3]] * 3, device=torch_device, dtype=torch.int64)
|
||||
|
||||
# Combining common prefix with the unique ending tokens:
|
||||
input_ids_shared_prefix = torch.cat([input_ids[0][:-1], input_ids[:, -1]]).unsqueeze(0)
|
||||
|
||||
# Creating a 4D mask where each of the last 3 tokens do not attend to each other.
|
||||
mask_shared_prefix = torch.tensor(
|
||||
[
|
||||
[
|
||||
[
|
||||
[1, 0, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 1, 0, 0],
|
||||
[1, 1, 1, 0, 1, 0],
|
||||
[1, 1, 1, 0, 0, 1],
|
||||
]
|
||||
]
|
||||
],
|
||||
)
|
||||
# inverting the attention mask
|
||||
mask_dtype = torch.float32
|
||||
min_dtype = torch.finfo(mask_dtype).min
|
||||
mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=mask_dtype, device=torch_device) * min_dtype
|
||||
|
||||
# Creating a position_ids tensor. note the repeating figures in the end.
|
||||
position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 3, 3]], device=torch_device, dtype=torch.int64)
|
||||
|
||||
return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
|
||||
|
||||
def test_custom_4d_attention_mask(self):
|
||||
if len(self.all_generative_model_classes) == 0:
|
||||
self.skipTest("Model architecture has no generative classes, and thus not necessarily supporting 4D masks")
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_cache_class:
|
||||
self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks")
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config).to(device=torch_device, dtype=torch.float32)
|
||||
|
||||
(
|
||||
input_ids,
|
||||
position_ids,
|
||||
input_ids_shared_prefix,
|
||||
mask_shared_prefix,
|
||||
position_ids_shared_prefix,
|
||||
) = self._get_custom_4d_mask_test_data()
|
||||
|
||||
logits = model.forward(input_ids, position_ids=position_ids).logits
|
||||
# logits.shape == torch.Size([3, 4, ...])
|
||||
|
||||
logits_shared_prefix = model(
|
||||
input_ids_shared_prefix,
|
||||
attention_mask=mask_shared_prefix,
|
||||
position_ids=position_ids_shared_prefix,
|
||||
)[0]
|
||||
# logits_shared_prefix.shape == torch.Size([1, 6, ...])
|
||||
|
||||
out_last_tokens = logits[:, -1, :] # last tokens in each batch line
|
||||
out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens
|
||||
|
||||
# comparing greedily-chosen tokens:
|
||||
assert torch.equal(out_last_tokens.max(axis=1).indices, out_shared_prefix_last_tokens.max(axis=1).indices)
|
||||
|
||||
# comparing softmax-normalized logits:
|
||||
normalized_0 = F.softmax(out_last_tokens)
|
||||
normalized_1 = F.softmax(out_shared_prefix_last_tokens)
|
||||
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user