[⚠️ removed a default argument] Make AttentionMaskConverter compatible with torch.compile(..., fullgraph=True) (#27868)
* remove buggy torch.float32 default * add test * fix tests * fix test * fix doc
This commit is contained in:
@@ -33,7 +33,7 @@ class AttentionMaskConverter:
|
|||||||
>>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
>>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
|
|
||||||
>>> converter = AttentionMaskConverter(True)
|
>>> converter = AttentionMaskConverter(True)
|
||||||
>>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, 5)
|
>>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
|
||||||
tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
|
tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
|
||||||
[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
|
[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
|
||||||
[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
|
[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
|
||||||
@@ -66,7 +66,7 @@ class AttentionMaskConverter:
|
|||||||
batch_size: int,
|
batch_size: int,
|
||||||
query_length: int,
|
query_length: int,
|
||||||
key_value_length: int,
|
key_value_length: int,
|
||||||
dtype: torch.dtype = torch.float32,
|
dtype: torch.dtype,
|
||||||
device: Union[torch.device, "str"] = "cpu",
|
device: Union[torch.device, "str"] = "cpu",
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
@@ -98,8 +98,8 @@ class AttentionMaskConverter:
|
|||||||
self,
|
self,
|
||||||
attention_mask_2d: torch.Tensor,
|
attention_mask_2d: torch.Tensor,
|
||||||
query_length: int,
|
query_length: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
key_value_length: Optional[int] = None,
|
key_value_length: Optional[int] = None,
|
||||||
dtype: torch.dtype = torch.float32,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
|
Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
|
||||||
@@ -215,7 +215,7 @@ def _prepare_4d_causal_attention_mask(
|
|||||||
# 4d mask is passed through the layers
|
# 4d mask is passed through the layers
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
attention_mask = attn_mask_converter.to_4d(
|
attention_mask = attn_mask_converter.to_4d(
|
||||||
attention_mask, input_shape[-1], key_value_length, dtype=inputs_embeds.dtype
|
attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
attention_mask = attn_mask_converter.to_causal_4d(
|
attention_mask = attn_mask_converter.to_causal_4d(
|
||||||
|
|||||||
@@ -85,7 +85,12 @@ if is_torch_available():
|
|||||||
T5Config,
|
T5Config,
|
||||||
T5ForConditionalGeneration,
|
T5ForConditionalGeneration,
|
||||||
)
|
)
|
||||||
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
from transformers.modeling_attn_mask_utils import (
|
||||||
|
AttentionMaskConverter,
|
||||||
|
_create_4d_causal_attention_mask,
|
||||||
|
_prepare_4d_attention_mask,
|
||||||
|
_prepare_4d_causal_attention_mask,
|
||||||
|
)
|
||||||
from transformers.modeling_utils import shard_checkpoint
|
from transformers.modeling_utils import shard_checkpoint
|
||||||
|
|
||||||
# Fake pretrained models for tests
|
# Fake pretrained models for tests
|
||||||
@@ -150,6 +155,32 @@ if is_torch_available():
|
|||||||
def tie_weights(self):
|
def tie_weights(self):
|
||||||
self.decoder.weight = self.base.linear.weight
|
self.decoder.weight = self.base.linear.weight
|
||||||
|
|
||||||
|
class Prepare4dCausalAttentionMaskModel(nn.Module):
    """Tiny test module whose forward pass only calls `_prepare_4d_causal_attention_mask`.

    It is used to check that the helper can be traced by `torch.compile` when no
    2D attention mask is supplied.
    """

    def forward(self, inputs_embeds):
        """Return the mask built for `inputs_embeds` with a fixed past length of 4.

        Args:
            inputs_embeds: tensor whose shape unpacks into exactly three sizes;
                the first two are forwarded as the input shape.
        """
        # Unpacking into three names keeps the implicit "exactly 3 dims" check.
        bsz, seq_len, _ = inputs_embeds.shape
        past_len = 4
        # `None` stands for "no 2D attention mask provided".
        return _prepare_4d_causal_attention_mask(None, (bsz, seq_len), inputs_embeds, past_len)
|
||||||
|
|
||||||
|
class Create4dCausalAttentionMaskModel(nn.Module):
    """Tiny test module whose forward pass only calls `_create_4d_causal_attention_mask`."""

    def forward(self, inputs_embeds):
        """Return the mask created from the shape, dtype and device of `inputs_embeds`.

        Args:
            inputs_embeds: tensor whose shape unpacks into exactly three sizes;
                the first two are forwarded as the input shape.
        """
        # Unpacking into three names keeps the implicit "exactly 3 dims" check.
        bsz, seq_len, _ = inputs_embeds.shape
        # dtype and device are taken from the embeddings; the past length is fixed at 4.
        mask_kwargs = {
            "dtype": inputs_embeds.dtype,
            "device": inputs_embeds.device,
            "past_key_values_length": 4,
        }
        return _create_4d_causal_attention_mask((bsz, seq_len), **mask_kwargs)
|
||||||
|
|
||||||
|
class Prepare4dAttentionMaskModel(nn.Module):
    """Tiny test module whose forward pass only calls `_prepare_4d_attention_mask`."""

    def forward(self, mask, inputs_embeds):
        """Return `_prepare_4d_attention_mask(mask)` using the dtype of `inputs_embeds`.

        Args:
            mask: attention mask forwarded unchanged to the helper.
            inputs_embeds: tensor used only as the source of the target dtype.
        """
        target_dtype = inputs_embeds.dtype
        return _prepare_4d_attention_mask(mask, dtype=target_dtype)
|
||||||
|
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
from transformers import FlaxBertModel
|
from transformers import FlaxBertModel
|
||||||
@@ -1493,7 +1524,7 @@ class AttentionMaskTester(unittest.TestCase):
|
|||||||
for bsz_idx, seq_idx in additional_mask:
|
for bsz_idx, seq_idx in additional_mask:
|
||||||
mask_2d[bsz_idx, seq_idx] = 0
|
mask_2d[bsz_idx, seq_idx] = 0
|
||||||
|
|
||||||
mask_4d = mask_converter.to_4d(mask_2d, query_length=q_len, key_value_length=kv_len)
|
mask_4d = mask_converter.to_4d(mask_2d, query_length=q_len, key_value_length=kv_len, dtype=torch.float32)
|
||||||
|
|
||||||
assert mask_4d.shape == (bsz, 1, q_len, kv_len)
|
assert mask_4d.shape == (bsz, 1, q_len, kv_len)
|
||||||
|
|
||||||
@@ -1529,7 +1560,9 @@ class AttentionMaskTester(unittest.TestCase):
|
|||||||
self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
|
self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
|
||||||
|
|
||||||
def check_to_causal(self, mask_converter, q_len, kv_len, bsz=3):
|
def check_to_causal(self, mask_converter, q_len, kv_len, bsz=3):
|
||||||
mask_4d = mask_converter.to_causal_4d(bsz, query_length=q_len, key_value_length=kv_len, device=torch_device)
|
mask_4d = mask_converter.to_causal_4d(
|
||||||
|
bsz, query_length=q_len, key_value_length=kv_len, device=torch_device, dtype=torch.float32
|
||||||
|
)
|
||||||
|
|
||||||
if q_len == 1 and mask_converter.sliding_window is None:
|
if q_len == 1 and mask_converter.sliding_window is None:
|
||||||
# no causal mask if q_len is 1
|
# no causal mask if q_len is 1
|
||||||
@@ -1621,3 +1654,38 @@ class AttentionMaskTester(unittest.TestCase):
|
|||||||
self.check_to_causal(mask_converter, q_len=3, kv_len=7)
|
self.check_to_causal(mask_converter, q_len=3, kv_len=7)
|
||||||
# non auto-regressive case
|
# non auto-regressive case
|
||||||
self.check_to_causal(mask_converter, q_len=7, kv_len=7)
|
self.check_to_causal(mask_converter, q_len=7, kv_len=7)
|
||||||
|
|
||||||
|
def test_torch_compile_fullgraph(self):
    """Check that each mask helper gives the same result eagerly and under `torch.compile(..., fullgraph=True)`."""

    def assert_compiled_matches_eager(module, *args):
        # Eager run first, so the reference result never goes through the compiled path.
        eager_out = module(*args)
        compiled_out = torch.compile(module, fullgraph=True)(*args)
        self.assertTrue(torch.equal(eager_out, compiled_out))

    # Mask built through `_prepare_4d_causal_attention_mask`.
    assert_compiled_matches_eager(Prepare4dCausalAttentionMaskModel(), torch.rand([1, 3, 32]))

    # Mask built through `_create_4d_causal_attention_mask`.
    assert_compiled_matches_eager(Create4dCausalAttentionMaskModel(), torch.rand(2, 4, 16))

    # Mask expanded through `_prepare_4d_attention_mask`; the first two
    # positions of the first row are zeroed.
    mask_2d = torch.ones(2, 4)
    mask_2d[0, :2] = 0
    assert_compiled_matches_eager(Prepare4dAttentionMaskModel(), mask_2d, torch.rand(2, 4, 16))
|
||||||
|
|||||||
Reference in New Issue
Block a user