From 1e21c4fbe04a77e1bb414ed7869bc69219d955eb Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 13 Mar 2024 15:07:52 +0000 Subject: [PATCH] Llama: allow custom 4d masks (#29618) --- .../models/gemma/modeling_gemma.py | 13 +++-- .../models/llama/modeling_llama.py | 13 +++-- tests/test_modeling_utils.py | 50 +++++++------------ 3 files changed, 35 insertions(+), 41 deletions(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index cbb074fcc1..7156fe8e5a 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -975,11 +975,16 @@ class GemmaModel(GemmaPreTrainedModel): causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype causal_mask = causal_mask.expand(batch_size, 1, -1, -1) - if attention_mask is not None and attention_mask.dim() == 2: + if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) - causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + elif attention_mask.dim() == 4: + mask_shape = attention_mask.shape + mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype + causal_mask[: mask_shape[0], : mask_shape[1], : mask_shape[2], : mask_shape[3]] = mask_slice if ( self.config._attn_implementation == "sdpa" diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 3752a92c83..5e2a9c2e5b 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1083,11 +1083,16 @@ class LlamaModel(LlamaPreTrainedModel): min_dtype = torch.finfo(dtype).min causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype causal_mask = causal_mask.expand(batch_size, 1, -1, -1) - if attention_mask is not None and attention_mask.dim() == 2: + if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) - causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + elif attention_mask.dim() == 4: + mask_shape = attention_mask.shape + mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype + causal_mask[: mask_shape[0], : mask_shape[1], : mask_shape[2], : mask_shape[3]] = mask_slice if ( self.config._attn_implementation == "sdpa" diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 1b2351be93..4bc66b7575 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -1992,6 +1992,8 @@ class Mask4DTestBase(unittest.TestCase): # [ 1, 278, 6635, 750], # [ 1, 278, 6635, 338]], device='cuda:0') + position_ids_0 = torch.tensor([[0, 1, 2, 3]] * 3, device=torch_device, dtype=torch.int64) + # Combining common prefix with the unique ending tokens: input_1 = torch.cat([input_0[0][:-1], input_0[:, -1]]).unsqueeze(0) # tensor([[ 1, 278, 6635, 3290, 750, 338]], device='cuda:0') @@ -2017,81 +2019,63 @@ class Mask4DTestBase(unittest.TestCase): # Creating a position_ids tensor. note the repeating figures in the end. position_ids_1 = torch.tensor([[0, 1, 2, 3, 3, 3]], device=torch_device, dtype=torch.int64) - return input_0, input_1, mask_1, position_ids_1 + return input_0, position_ids_0, input_1, mask_1, position_ids_1 -@slow @require_torch_gpu class Mask4DTestFP32(Mask4DTestBase): def setUp(self): model_name = "JackFram/llama-68m" # small Llama-like model from FlexFlow - model_dtype = torch.float32 + self.model_dtype = torch.float32 self.tokenizer = AutoTokenizer.from_pretrained(model_name) - self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype).to(torch_device) + self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device) def test_attention(self): """comparing outputs of attention layer""" - input_0, input_1, mask_1, position_ids_1 = self.get_test_data() + input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data() + causal_mask_1 = (1 - mask_1).to(self.model_dtype) * torch.finfo(self.model_dtype).min hid_0 = self.model.model.embed_tokens(input_0) - outs_0 = self.model.model.layers[0].self_attn.forward(hid_0)[0] + outs_0 = self.model.model.layers[0].self_attn.forward(hid_0, position_ids=position_ids_0)[0] # outs_0.shape == torch.Size([3, 4, 768]) hid_1 = self.model.model.embed_tokens(input_1) outs_1 = self.model.model.layers[0].self_attn.forward( - hid_1, attention_mask=mask_1.bool(), position_ids=position_ids_1 + hid_1, attention_mask=causal_mask_1, position_ids=position_ids_1 )[0] # outs_1.shape == torch.Size([1, 6, 768]) outs_0_last_tokens = outs_0[:, -1, :] # last tokens in each batch line outs_1_last_tokens = outs_1[0, -3:, :] # last three tokens - assert torch.allclose(outs_0_last_tokens, outs_1_last_tokens) - - def test_inner_model(self): - """comparing hidden outputs of whole inner model""" - input_0, input_1, mask_1, position_ids_1 = self.get_test_data() - - logits_0 = self.model.forward(input_0).logits - logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits - - logits_0_last_tokens = logits_0[:, -1, :] # last tokens in each batch line - logits_1_last_tokens = logits_1[0, -3:, :] # last three tokens - torch.testing.assert_close( - logits_0_last_tokens, - logits_1_last_tokens, - ) + torch.testing.assert_close(outs_0_last_tokens, outs_1_last_tokens) def test_causal_model_logits(self): """comparing logits outputs of whole inner model""" - input_0, input_1, mask_1, position_ids_1 = self.get_test_data() + input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data() - logits_0 = self.model.forward(input_0).logits + logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits logits_0_last_tokens = logits_0[:, -1, :] # last tokens in each batch line logits_1_last_tokens = logits_1[0, -3:, :] # last three tokens - torch.testing.assert_close( - logits_0_last_tokens, - logits_1_last_tokens, - ) + torch.testing.assert_close(logits_0_last_tokens, logits_1_last_tokens) -@slow @require_torch_gpu class Mask4DTestFP16(Mask4DTestBase): test_attention = Mask4DTestFP32.test_attention def setUp(self): model_name = "JackFram/llama-68m" # small Llama-like model from FlexFlow - model_dtype = torch.float16 + self.model_dtype = torch.float16 self.tokenizer = AutoTokenizer.from_pretrained(model_name) - self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype).to(torch_device) + self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device) def test_causal_model_logits(self): """comparing logits outputs of whole inner model""" - input_0, input_1, mask_1, position_ids_1 = self.get_test_data() + input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data() - logits_0 = self.model.forward(input_0).logits + logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits logits_0_last_tokens = logits_0[:, -1, :] # last tokens in each batch line