Llama: fix custom 4D masks, v2 (#30348)
* 4d mask fixes
* Update custom 4D mask logic
* test moved to mixin
* extra tests 4d mask
* upd 4d mask and StaticCache handling
* added Mask4DTestHard to mistral tests
* post-rebase fixes
* test fixes for StaticCache
* make fix-copies
* upd 1 after #30476
* fix common tests
* rm elif attention_mask.dim() == 4:
* tests combined, fixed, mixtral supported
* bigbird style chg reverted
* rm if attention_mask.dim() == 2
* modeling_llama formatting chg

---------

Co-authored-by: Joao Gante <joao@huggingface.co>
This commit is contained in:
@@ -627,3 +627,127 @@ class MistralIntegrationTest(unittest.TestCase):
|
||||
del model
|
||||
backend_empty_cache(torch_device)
|
||||
gc.collect()
|
||||
|
||||
|
||||
@slow
@require_torch_gpu
class Mask4DTestHard(unittest.TestCase):
    """Check that custom 4D attention masks reproduce regular batched results.

    Three prompts that share the prefix ``"my favorite"`` are evaluated twice:
    once as an ordinary batch of three rows, and once packed into a single row
    whose custom 4D mask lets every continuation attend only to the shared
    prefix and to itself. The greedy next-token predictions must be identical.
    """

    def tearDown(self):
        """Release the model and tokenizer, then free cached accelerator memory."""
        # Drop the references first: as long as `self` still holds the model,
        # gc.collect() has nothing to reclaim.
        self.model = None
        self.tokenizer = None
        gc.collect()
        # Same device-agnostic helper the other tests in this file use.
        backend_empty_cache(torch_device)

    def setUp(self):
        """Load the Mistral-7B tokenizer and model (float32) onto the test device."""
        model_name = "mistralai/Mistral-7B-v0.1"
        self.model_dtype = torch.float32
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
        self.model = MistralForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)

    def get_test_data(self):
        """Build the inputs for the regular batch and for the packed shared-prefix run.

        Returns:
            A 5-tuple ``(input_ids, position_ids, input_ids_shared_prefix,
            mask_shared_prefix, position_ids_shared_prefix)`` where

            * ``input_ids`` / ``position_ids`` describe the three prompts as a
              regular batch (one prompt per row);
            * ``input_ids_shared_prefix`` is the single packed prompt;
            * ``mask_shared_prefix`` is the additive 4D mask for the packed
              prompt (``0`` where attention is allowed, the dtype minimum
              where it is blocked);
            * ``position_ids_shared_prefix`` are the position ids derived
              from that mask.
        """
        template = "my favorite {}"
        items = ("pet is a", "artist plays a", "name is L")  # same number of tokens in each item

        batch_separate = [template.format(x) for x in items]  # 3 separate lines
        batch_shared_prefix = template.format(" ".join(items))  # 1 line with options concatenated

        input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device)
        input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device)

        # Rows are query positions, columns are key positions; 1 = may attend.
        # Rows 0-5 are plain causal attention over the prefix and the first
        # item; rows 6-8 and 9-11 see the three prefix columns plus their own
        # item only.
        mask_shared_prefix = torch.tensor(
            [
                [
                    [
                        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
                        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
                        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
                        [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
                        [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
                        [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0],
                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0],
                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1],
                    ]
                ]
            ],
            device=torch_device,
        )

        position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device)

        # building custom positions ids based on custom mask
        position_ids_shared_prefix = (mask_shared_prefix.sum(dim=-1) - 1).reshape(1, -1)
        # effectively: position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device)

        # inverting the mask: allowed positions become 0, blocked ones the dtype minimum
        min_dtype = torch.finfo(self.model_dtype).min
        mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=self.model_dtype) * min_dtype

        return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix

    def _greedy_tokens(self, token_logits):
        """Decode the argmax token of every row of ``token_logits`` into text."""
        return [self.tokenizer.decode(t) for t in token_logits.argmax(dim=-1)]

    def _decode_regular_batch(self, input_ids, position_ids):
        """Run the regular batch and decode the prediction after each row's last token."""
        with torch.no_grad():
            logits = self.model(input_ids, position_ids=position_ids).logits
        return self._greedy_tokens(logits[:, -1, :])  # last tokens in each batch line

    def test_stacked_causal_mask(self):
        """A single forward pass with the full custom 4D mask matches the regular batch."""
        (
            input_ids,
            position_ids,
            input_ids_shared_prefix,
            mask_shared_prefix,
            position_ids_shared_prefix,
        ) = self.get_test_data()

        # regular batch
        decoded = self._decode_regular_batch(input_ids, position_ids)

        # single forward run with 4D custom mask
        with torch.no_grad():
            logits_shared_prefix = self.model(
                input_ids_shared_prefix, attention_mask=mask_shared_prefix, position_ids=position_ids_shared_prefix
            ).logits
        # The final token of every packed item is where the position id peaks.
        last_positions = torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1]
        decoded_shared_prefix = self._greedy_tokens(logits_shared_prefix[0, last_positions, :])  # last three tokens

        self.assertEqual(decoded, decoded_shared_prefix)

    def test_partial_stacked_causal_mask(self):
        """Same as the test above, but the input is passed in two groups.

        It tests that we can pass partial 4D attention masks.
        """
        (
            input_ids,
            position_ids,
            input_ids_shared_prefix,
            mask_shared_prefix,
            position_ids_shared_prefix,
        ) = self.get_test_data()

        # regular batch
        decoded = self._decode_regular_batch(input_ids, position_ids)

        # 2 forward runs with custom 4D masks
        part_a = 3  # split point

        input_1a = input_ids_shared_prefix[:, :part_a]
        position_ids_1a = position_ids_shared_prefix[:, :part_a]
        mask_1a = mask_shared_prefix[:, :, :part_a, :part_a]

        with torch.no_grad():
            outs_1a = self.model(input_1a, attention_mask=mask_1a, position_ids=position_ids_1a)
        past_key_values_a = outs_1a["past_key_values"]

        # The second call passes a 4D attention mask sized for the current
        # chunk against the full sequence (i.e. [..., seq_len, full_len]).
        input_1b = input_ids_shared_prefix[:, part_a:]
        position_ids_1b = position_ids_shared_prefix[:, part_a:]
        mask_1b = mask_shared_prefix[:, :, part_a:, :]
        with torch.no_grad():
            outs_1b = self.model(
                input_1b, attention_mask=mask_1b, position_ids=position_ids_1b, past_key_values=past_key_values_a
            )

        # Indices are shifted by `part_a` because the second call only
        # produced logits for the tokens after the split point.
        last_positions = torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a
        decoded_1b = self._greedy_tokens(outs_1b.logits[0, last_positions, :])
        self.assertEqual(decoded, decoded_1b)
|
||||
|
||||
Reference in New Issue
Block a user