From 4294f0c3583c3a361406b2b6e8bda05ad1af459e Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 19 Mar 2024 17:32:01 +0000 Subject: [PATCH] Llama: partial 4d masks (#29731) * partial 4d masks * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../models/cohere/modeling_cohere.py | 14 ++- .../models/gemma/modeling_gemma.py | 14 ++- .../models/llama/modeling_llama.py | 14 ++- tests/test_modeling_utils.py | 110 +++++++++++++++++- 4 files changed, 142 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index a559f37bac..4460d6ce2e 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -899,7 +899,7 @@ class CohereModel(CoherePreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, past_seen_tokens) # embed positions hidden_states = inputs_embeds @@ -967,7 +967,7 @@ class CohereModel(CoherePreTrainedModel): # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 - def _update_causal_mask(self, attention_mask, input_tensor): + def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask @@ -993,9 +993,17 @@ class CohereModel(CoherePreTrainedModel): padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) elif attention_mask.dim() == 4: + # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with + # cache. In that case, the 4D attention mask attends to the newest tokens only. + if attention_mask.shape[-2] < past_seen_tokens + input_tensor.shape[1]: + offset = past_seen_tokens + else: + offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[: mask_shape[0], : mask_shape[1], : mask_shape[2], : mask_shape[3]] = mask_slice + causal_mask[ + : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] + ] = mask_slice if ( self.config._attn_implementation == "sdpa" diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 8ec5d64ade..ad7cc769be 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -901,7 +901,7 @@ class GemmaModel(GemmaPreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, past_seen_tokens) # embed positions hidden_states = inputs_embeds @@ -975,7 +975,7 @@ class GemmaModel(GemmaPreTrainedModel): # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 - def _update_causal_mask(self, attention_mask, input_tensor): + def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask @@ -1002,9 +1002,17 @@ class GemmaModel(GemmaPreTrainedModel): padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) elif attention_mask.dim() == 4: + # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with + # cache. In that case, the 4D attention mask attends to the newest tokens only. + if attention_mask.shape[-2] < past_seen_tokens + input_tensor.shape[1]: + offset = past_seen_tokens + else: + offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[: mask_shape[0], : mask_shape[1], : mask_shape[2], : mask_shape[3]] = mask_slice + causal_mask[ + : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] + ] = mask_slice if ( self.config._attn_implementation == "sdpa" diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 7b9cca330e..ae45c8b170 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1000,7 +1000,7 @@ class LlamaModel(LlamaPreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, past_seen_tokens) # embed positions hidden_states = inputs_embeds @@ -1068,7 +1068,7 @@ class LlamaModel(LlamaPreTrainedModel): # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 - def _update_causal_mask(self, attention_mask, input_tensor): + def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask @@ -1094,9 +1094,17 @@ class LlamaModel(LlamaPreTrainedModel): padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) elif attention_mask.dim() == 4: + # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with + # cache. In that case, the 4D attention mask attends to the newest tokens only. + if attention_mask.shape[-2] < past_seen_tokens + input_tensor.shape[1]: + offset = past_seen_tokens + else: + offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[: mask_shape[0], : mask_shape[1], : mask_shape[2], : mask_shape[3]] = mask_slice + causal_mask[ + : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] + ] = mask_slice if ( self.config._attn_implementation == "sdpa" diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index c88b9c8870..46df1feae9 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -1956,7 +1956,6 @@ class TestAttentionImplementation(unittest.TestCase): self.assertTrue("PyTorch SDPA requirements in Transformers are not met" in str(cm.exception)) -@slow @require_torch_gpu class Mask4DTestBase(unittest.TestCase): def tearDown(self): @@ -2011,6 +2010,7 @@ class Mask4DTestFP32(Mask4DTestBase): def test_attention(self): """comparing outputs of attention layer""" + # Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data() causal_mask_1 = (1 - mask_1).to(self.model_dtype) * torch.finfo(self.model_dtype).min @@ -2030,6 +2030,7 @@ class Mask4DTestFP32(Mask4DTestBase): def test_causal_model_logits(self): """comparing logits outputs of whole inner model""" + # Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data() logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits @@ -2052,6 +2053,7 @@ class Mask4DTestFP16(Mask4DTestBase): def test_causal_model_logits(self): """comparing logits outputs of whole inner model""" + # Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data() logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits @@ -2069,3 +2071,109 @@ class Mask4DTestFP16(Mask4DTestBase): # checking tokens order for the top tokens for token_ids_0, token_ids_1 in zip(indices_0, indices_1): self.assertTrue(torch.equal(token_ids_0[:128], token_ids_1[:128])) + + +@slow +@require_torch_gpu +class Mask4DTestHard(unittest.TestCase): + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + def setUp(self): + model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + self.model_dtype = torch.float32 + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device) + + def get_test_data(self): + template = "my favorite {}" + items = ("pet is a", "artist plays a", "name is L") # same number of tokens in each item + + batch_0 = [template.format(x) for x in items] # 3 separate lines + batch_1 = template.format(" ".join(items)) # 1 line with options concatenated + + input_0 = self.tokenizer(batch_0, return_tensors="pt").input_ids.to(torch_device) + input_1 = self.tokenizer(batch_1, return_tensors="pt").input_ids.to(torch_device) + + mask_1 = torch.tensor( + [ + [ + [ + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1], + ] + ] + ], + device=torch_device, + dtype=torch.int64, + ) + + position_ids_0 = torch.arange(input_0.shape[1]).tile(input_0.shape[0], 1).to(torch_device) + # equivalent: position_ids_1 = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device) + position_ids_1 = (mask_1.sum(dim=-1) - 1).reshape(1, -1) # same but nicer + + return input_0, position_ids_0, input_1, mask_1, position_ids_1 + + def test_stacked_causal_mask(self): + # Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention + input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data() + + # regular batch + logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits + logits_0_last = logits_0[:, -1, :] # last tokens in each batch line + decoded_0 = [self.tokenizer.decode(t) for t in logits_0_last.argmax(dim=-1)] + + # single forward run with 4D custom mask + logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits + logits_1_last = logits_1[0, torch.where(position_ids_1 == position_ids_1.max())[1], :] # last three tokens + decoded_1 = [self.tokenizer.decode(t) for t in logits_1_last.argmax(dim=-1)] + + self.assertEqual(decoded_0, decoded_1) + + def test_partial_stacked_causal_mask(self): + # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention + # masks + + # Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention + input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data() + + # regular batch + logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits + logits_0_last = logits_0[:, -1, :] # last tokens in each batch line + decoded_0 = [self.tokenizer.decode(t) for t in logits_0_last.argmax(dim=-1)] + + # 2 forward runs with custom 4D masks + part_a = 3 # split point + + input_1a = input_1[:, :part_a] + position_ids_1a = position_ids_1[:, :part_a] + mask_1a = mask_1[:, :, :part_a, :part_a] + + outs_1a = self.model.forward(input_1a, attention_mask=mask_1a.bool(), position_ids=position_ids_1a) + past_key_values_a = outs_1a["past_key_values"] + + input_1b = input_1[:, part_a:] + position_ids_1b = position_ids_1[:, part_a:] + mask_1b = mask_1[:, :, part_a:, :] + + outs_1b = self.model.forward( + input_1b, attention_mask=mask_1b.bool(), position_ids=position_ids_1b, past_key_values=past_key_values_a + ) + + decoded_1b = [ + self.tokenizer.decode(t) + for t in outs_1b.logits.argmax(-1)[0, torch.where(position_ids_1 == position_ids_1.max())[1] - part_a] + ] + + self.assertEqual(decoded_0, decoded_1b)