From 307a7d0be81f44734ed4ef4bc5613ac4d7f8bdec Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 8 Dec 2023 10:44:47 +0100 Subject: [PATCH] =?UTF-8?q?[=E2=9A=A0=EF=B8=8F=20removed=20a=20default=20a?= =?UTF-8?q?rgument]=20Make=20`AttentionMaskConverter`=20compatible=20with?= =?UTF-8?q?=20`torch.compile(...,=20fullgraph=3DTrue)`=20(#27868)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * remove bugged torch.float32 default * add test * fix tests * fix test * fix doc --- src/transformers/modeling_attn_mask_utils.py | 8 +-- tests/test_modeling_utils.py | 74 +++++++++++++++++++- 2 files changed, 75 insertions(+), 7 deletions(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 9658adc55d..2c4a839ae1 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -33,7 +33,7 @@ class AttentionMaskConverter: >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter >>> converter = AttentionMaskConverter(True) - >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, 5) + >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32) tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], @@ -66,7 +66,7 @@ class AttentionMaskConverter: batch_size: int, query_length: int, key_value_length: int, - dtype: torch.dtype = torch.float32, + dtype: torch.dtype, device: Union[torch.device, "str"] = "cpu", ) -> torch.Tensor: """ @@ -98,8 +98,8 @@ class AttentionMaskConverter: self, attention_mask_2d: torch.Tensor, query_length: int, + dtype: torch.dtype, key_value_length: Optional[int] = None, - dtype: torch.dtype = torch.float32, ) -> torch.Tensor: """ Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, @@ -215,7 +215,7 @@ def _prepare_4d_causal_attention_mask( # 4d mask is passed through the layers if attention_mask is not None: attention_mask = attn_mask_converter.to_4d( - attention_mask, input_shape[-1], key_value_length, dtype=inputs_embeds.dtype + attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype ) else: attention_mask = attn_mask_converter.to_causal_4d( diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index fd4297de1c..e1c37ec268 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -85,7 +85,12 @@ if is_torch_available(): T5Config, T5ForConditionalGeneration, ) - from transformers.modeling_attn_mask_utils import AttentionMaskConverter + from transformers.modeling_attn_mask_utils import ( + AttentionMaskConverter, + _create_4d_causal_attention_mask, + _prepare_4d_attention_mask, + _prepare_4d_causal_attention_mask, + ) from transformers.modeling_utils import shard_checkpoint # Fake pretrained models for tests @@ -150,6 +155,32 @@ if is_torch_available(): def tie_weights(self): self.decoder.weight = self.base.linear.weight + class Prepare4dCausalAttentionMaskModel(nn.Module): + def forward(self, inputs_embeds): + batch_size, seq_length, _ = inputs_embeds.shape + past_key_values_length = 4 + attention_mask = _prepare_4d_causal_attention_mask( + None, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + return attention_mask + + class Create4dCausalAttentionMaskModel(nn.Module): + def forward(self, inputs_embeds): + batch_size, seq_length, _ = inputs_embeds.shape + past_key_values_length = 4 + attention_mask = _create_4d_causal_attention_mask( + (batch_size, seq_length), + dtype=inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + return attention_mask + + class Prepare4dAttentionMaskModel(nn.Module): + def forward(self, mask, inputs_embeds): + attention_mask = _prepare_4d_attention_mask(mask, dtype=inputs_embeds.dtype) + return attention_mask + if is_flax_available(): from transformers import FlaxBertModel @@ -1493,7 +1524,7 @@ class AttentionMaskTester(unittest.TestCase): for bsz_idx, seq_idx in additional_mask: mask_2d[bsz_idx, seq_idx] = 0 - mask_4d = mask_converter.to_4d(mask_2d, query_length=q_len, key_value_length=kv_len) + mask_4d = mask_converter.to_4d(mask_2d, query_length=q_len, key_value_length=kv_len, dtype=torch.float32) assert mask_4d.shape == (bsz, 1, q_len, kv_len) @@ -1529,7 +1560,9 @@ class AttentionMaskTester(unittest.TestCase): self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d) def check_to_causal(self, mask_converter, q_len, kv_len, bsz=3): - mask_4d = mask_converter.to_causal_4d(bsz, query_length=q_len, key_value_length=kv_len, device=torch_device) + mask_4d = mask_converter.to_causal_4d( + bsz, query_length=q_len, key_value_length=kv_len, device=torch_device, dtype=torch.float32 + ) if q_len == 1 and mask_converter.sliding_window is None: # no causal mask if q_len is 1 @@ -1621,3 +1654,38 @@ class AttentionMaskTester(unittest.TestCase): self.check_to_causal(mask_converter, q_len=3, kv_len=7) # non auto-regressive case self.check_to_causal(mask_converter, q_len=7, kv_len=7) + + def test_torch_compile_fullgraph(self): + model = Prepare4dCausalAttentionMaskModel() + + inputs_embeds = torch.rand([1, 3, 32]) + res_non_compiled = model(inputs_embeds) + + compiled_model = torch.compile(model, fullgraph=True) + + res_compiled = compiled_model(inputs_embeds) + + self.assertTrue(torch.equal(res_non_compiled, res_compiled)) + + model = Create4dCausalAttentionMaskModel() + + inputs_embeds = torch.rand(2, 4, 16) + res_non_compiled = model(inputs_embeds) + + compiled_model = torch.compile(model, fullgraph=True) + res_compiled = compiled_model(inputs_embeds) + + self.assertTrue(torch.equal(res_non_compiled, res_compiled)) + + model = Prepare4dAttentionMaskModel() + + mask = torch.ones(2, 4) + mask[0, :2] = 0 + inputs_embeds = torch.rand(2, 4, 16) + + res_non_compiled = model(mask, inputs_embeds) + + compiled_model = torch.compile(model, fullgraph=True) + res_compiled = compiled_model(mask, inputs_embeds) + + self.assertTrue(torch.equal(res_non_compiled, res_compiled))