[Attention Mask] Refactor all encoder-decoder attention mask (#27086)
* [FA2 Bart] Add FA2 to all Bart-like * better * Refactor attention mask * remove all customized attention logic * format * mass rename * replace _expand_mask * replace _expand_mask * mass rename * add pt files * mass replace & rename * mass replace & rename * mass replace & rename * mass replace & rename * Update src/transformers/models/idefics/modeling_idefics.py * fix more * clean more * fix more * make style * fix again * finish * finish * finish * finish * finish * finish * finish * finish * finish * finish * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * small fix mistral * finish * finish * finish * finish --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
29c74f58ae
commit
ac5893756b
@@ -39,149 +39,6 @@ if is_torch_available():
|
||||
LlamaModel,
|
||||
LlamaTokenizer,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import AttnMaskConverter
|
||||
|
||||
|
||||
@require_torch
class AttentionMaskTester(unittest.TestCase):
    """Tests for the 4D attention masks produced by ``AttnMaskConverter``.

    The ``check_*`` helpers count the non-zero entries of the mask returned by the
    converter and compare that count with a closed-form expectation; the ``test_*``
    methods run those helpers for several query / key-value length combinations.
    """

    # (batch index, key position) pairs that the "extra attention mask" tests switch off.
    _EXTRA_MASKED_POSITIONS = [(0, 2), (1, 3), (2, 0)]

    def check_non_causal(self, bsz, q_len, kv_len, mask_2d, mask_4d):
        """Assert that every position disabled in ``mask_2d`` is masked in ``mask_4d``.

        Args:
            bsz: Batch size of both masks.
            q_len: Query length of ``mask_4d``.
            kv_len: Key/value length of both masks.
            mask_2d: ``(bsz, kv_len)`` tensor where ``1`` means "attend".
            mask_4d: Mask returned by the converter; only ``mask_4d[:, 0]`` is inspected.
        """
        # Repeat the per-key flags along the query axis so they can index mask_4d[:, 0].
        mask_indices = (mask_2d != 1)[:, None].broadcast_to((bsz, q_len, kv_len))
        mask_4d_values = mask_4d[:, 0][mask_indices]
        # A masked entry may be encoded either as -inf or as the dtype's minimum value.
        is_inf = mask_4d_values == -float("inf")
        is_min = mask_4d_values == torch.finfo(mask_4d.dtype).min
        self.assertTrue(torch.logical_or(is_inf, is_min).all().item())

    def _expected_num_causal_masked(self, mask_converter, bsz, q_len, kv_len):
        """Return the expected number of masked entries for a causal converter.

        The count is the causal triangle plus, when ``mask_converter.sliding_window``
        is set, the entries counted by ``compute_num_context_mask``; it is then scaled
        by the batch size.
        """
        # q_len * (q_len - 1) / 2 entries are masked in the causal triangle.
        num_tokens_masked = q_len * (q_len - 1) // 2
        context = mask_converter.sliding_window
        if context is not None:
            num_tokens_masked += self.compute_num_context_mask(kv_len, context, q_len)
        return bsz * num_tokens_masked

    def check_to_4d(self, mask_converter, q_len, kv_len, additional_mask=None, bsz=3):
        """Build a 2D mask, convert it with ``mask_converter.to_4d`` and validate the result.

        Args:
            mask_converter: Converter under test; ``is_causal`` and ``sliding_window`` are read.
            q_len: Query length requested from the converter.
            kv_len: Key/value length requested from the converter.
            additional_mask: Optional iterable of ``(batch_idx, seq_idx)`` pairs to set to 0
                in the otherwise all-ones 2D mask.
            bsz: Batch size of the generated mask.
        """
        mask_2d = torch.ones((bsz, kv_len), device=torch_device, dtype=torch.long)

        if additional_mask is not None:
            for bsz_idx, seq_idx in additional_mask:
                mask_2d[bsz_idx, seq_idx] = 0

        mask_4d = mask_converter.to_4d(mask_2d, query_length=q_len, key_value_length=kv_len)

        self.assertEqual(tuple(mask_4d.shape), (bsz, 1, q_len, kv_len))

        has_disabled_positions = 0 in mask_2d
        num_masked = (mask_4d != 0).sum().cpu().item()
        context = mask_converter.sliding_window

        if mask_converter.is_causal:
            expected = self._expected_num_causal_masked(mask_converter, bsz, q_len, kv_len)
            if has_disabled_positions:
                # at least the causal (and sliding-window) mask + maybe more
                self.assertGreaterEqual(num_masked, expected)
                self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
            else:
                self.assertEqual(num_masked, expected)
        elif context is None:
            if has_disabled_positions:
                self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
            else:
                self.assertEqual(num_masked, 0)
        # A non-causal converter with a sliding window is not checked here.

    def check_to_causal(self, mask_converter, q_len, kv_len, bsz=3):
        """Validate the mask returned by ``mask_converter.to_causal_4d``.

        Args:
            mask_converter: Converter under test; ``is_causal`` and ``sliding_window`` are read.
            q_len: Query length requested from the converter.
            kv_len: Key/value length requested from the converter.
            bsz: Batch size requested from the converter.
        """
        mask_4d = mask_converter.to_causal_4d(bsz, query_length=q_len, key_value_length=kv_len, device=torch_device)

        if q_len == 1 and mask_converter.sliding_window is None:
            # no causal mask if q_len is 1
            self.assertIsNone(mask_4d)
            return

        num_masked = (mask_4d != 0).sum().cpu().item()

        if mask_converter.is_causal:
            expected = self._expected_num_causal_masked(mask_converter, bsz, q_len, kv_len)
            self.assertEqual(num_masked, expected)
        elif mask_converter.sliding_window is None:
            self.assertEqual(num_masked, 0)

    def compute_num_context_mask(self, kv_len, context, q_len):
        """Return the number of entries (per batch item) added to the mask by the sliding window.

        Args:
            kv_len: Key/value length of the mask.
            context: Sliding-window size.
            q_len: Query length of the mask.

        Returns:
            ``T(c) - T(max(c - q_len, 0))`` with ``c = max(kv_len - context, 0)`` and
            ``T(n) = n * (n + 1) // 2``.
        """
        # Clamp at zero: a window larger than the key/value length masks nothing extra.
        # Without the clamp a negative length would yield a positive triangle count.
        c_mask_len = max(kv_len - context, 0)
        num_mask_triangle = c_mask_len * (c_mask_len + 1) // 2
        # With fewer than c_mask_len query rows, part of that triangle is cut off.
        cut_mask_len = max(c_mask_len - q_len, 0)
        num_cut_mask = cut_mask_len * (cut_mask_len + 1) // 2
        return num_mask_triangle - num_cut_mask

    def test_2d_to_4d_causal(self):
        """``to_4d`` with a causal converter, with and without disabled positions."""
        mask_converter = AttnMaskConverter(is_causal=True)

        # auto-regressive use case
        self.check_to_4d(mask_converter, q_len=1, kv_len=7)
        # special auto-regressive case
        self.check_to_4d(mask_converter, q_len=3, kv_len=7)
        # non auto-regressive case
        self.check_to_4d(mask_converter, q_len=7, kv_len=7)

        # same with extra attention masks
        self.check_to_4d(mask_converter, q_len=1, kv_len=7, additional_mask=self._EXTRA_MASKED_POSITIONS)
        self.check_to_4d(mask_converter, q_len=3, kv_len=7, additional_mask=self._EXTRA_MASKED_POSITIONS)
        self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=self._EXTRA_MASKED_POSITIONS)

    def test_2d_to_4d(self):
        """``to_4d`` with a non-causal converter, with and without disabled positions."""
        mask_converter = AttnMaskConverter(is_causal=False)

        # non auto-regressive case
        self.check_to_4d(mask_converter, q_len=7, kv_len=7)

        # same with extra attention masks
        self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=self._EXTRA_MASKED_POSITIONS)

    def test_2d_to_4d_causal_sliding(self):
        """``to_4d`` with a causal converter using a sliding window of 5."""
        mask_converter = AttnMaskConverter(is_causal=True, sliding_window=5)

        # auto-regressive use case
        self.check_to_4d(mask_converter, q_len=1, kv_len=7)
        # special auto-regressive case
        self.check_to_4d(mask_converter, q_len=3, kv_len=7)
        # non auto-regressive case
        self.check_to_4d(mask_converter, q_len=7, kv_len=7)

        # same with extra attention masks
        self.check_to_4d(mask_converter, q_len=1, kv_len=7, additional_mask=self._EXTRA_MASKED_POSITIONS)
        self.check_to_4d(mask_converter, q_len=3, kv_len=7, additional_mask=self._EXTRA_MASKED_POSITIONS)
        self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=self._EXTRA_MASKED_POSITIONS)

    def test_causal_mask(self):
        """``to_causal_4d`` with a causal converter and no sliding window."""
        mask_converter = AttnMaskConverter(is_causal=True)

        # auto-regressive use case
        self.check_to_causal(mask_converter, q_len=1, kv_len=7)
        # special auto-regressive case
        self.check_to_causal(mask_converter, q_len=3, kv_len=7)
        # non auto-regressive case
        self.check_to_causal(mask_converter, q_len=7, kv_len=7)

    def test_causal_mask_sliding(self):
        """``to_causal_4d`` with a causal converter using a sliding window of 3."""
        mask_converter = AttnMaskConverter(is_causal=True, sliding_window=3)

        # auto-regressive use case
        self.check_to_causal(mask_converter, q_len=1, kv_len=7)
        # special auto-regressive case
        self.check_to_causal(mask_converter, q_len=3, kv_len=7)
        # non auto-regressive case
        self.check_to_causal(mask_converter, q_len=7, kv_len=7)
|
||||
|
||||
|
||||
class LlamaModelTester:
|
||||
|
||||
@@ -48,6 +48,7 @@ from transformers.testing_utils import (
|
||||
require_torch_multi_gpu,
|
||||
require_usr_bin_time,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import (
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
@@ -79,6 +80,7 @@ if is_torch_available():
|
||||
T5Config,
|
||||
T5ForConditionalGeneration,
|
||||
)
|
||||
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from transformers.modeling_utils import shard_checkpoint
|
||||
|
||||
# Fake pretrained models for tests
|
||||
@@ -1184,3 +1186,143 @@ The commit description supports markdown synthax see:
|
||||
config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True)
|
||||
new_model = AutoModel.from_config(config, trust_remote_code=True)
|
||||
self.assertEqual(new_model.__class__.__name__, "CustomModel")
|
||||
|
||||
|
||||
@require_torch
class AttentionMaskTester(unittest.TestCase):
    """Tests for the 4D attention masks produced by ``AttentionMaskConverter``.

    The ``check_*`` helpers count the non-zero entries of the mask returned by the
    converter and compare that count with a closed-form expectation; the ``test_*``
    methods run those helpers for several query / key-value length combinations.
    """

    # (batch index, key position) pairs that the "extra attention mask" tests switch off.
    _EXTRA_MASKED_POSITIONS = [(0, 2), (1, 3), (2, 0)]

    def check_non_causal(self, bsz, q_len, kv_len, mask_2d, mask_4d):
        """Assert that every position disabled in ``mask_2d`` is masked in ``mask_4d``.

        Args:
            bsz: Batch size of both masks.
            q_len: Query length of ``mask_4d``.
            kv_len: Key/value length of both masks.
            mask_2d: ``(bsz, kv_len)`` tensor where ``1`` means "attend".
            mask_4d: Mask returned by the converter; only ``mask_4d[:, 0]`` is inspected.
        """
        # Repeat the per-key flags along the query axis so they can index mask_4d[:, 0].
        mask_indices = (mask_2d != 1)[:, None].broadcast_to((bsz, q_len, kv_len))
        mask_4d_values = mask_4d[:, 0][mask_indices]
        # A masked entry may be encoded either as -inf or as the dtype's minimum value.
        is_inf = mask_4d_values == -float("inf")
        is_min = mask_4d_values == torch.finfo(mask_4d.dtype).min
        self.assertTrue(torch.logical_or(is_inf, is_min).all().item())

    def _expected_num_causal_masked(self, mask_converter, bsz, q_len, kv_len):
        """Return the expected number of masked entries for a causal converter.

        The count is the causal triangle plus, when ``mask_converter.sliding_window``
        is set, the entries counted by ``compute_num_context_mask``; it is then scaled
        by the batch size.
        """
        # q_len * (q_len - 1) / 2 entries are masked in the causal triangle.
        num_tokens_masked = q_len * (q_len - 1) // 2
        context = mask_converter.sliding_window
        if context is not None:
            num_tokens_masked += self.compute_num_context_mask(kv_len, context, q_len)
        return bsz * num_tokens_masked

    def check_to_4d(self, mask_converter, q_len, kv_len, additional_mask=None, bsz=3):
        """Build a 2D mask, convert it with ``mask_converter.to_4d`` and validate the result.

        Args:
            mask_converter: Converter under test; ``is_causal`` and ``sliding_window`` are read.
            q_len: Query length requested from the converter.
            kv_len: Key/value length requested from the converter.
            additional_mask: Optional iterable of ``(batch_idx, seq_idx)`` pairs to set to 0
                in the otherwise all-ones 2D mask.
            bsz: Batch size of the generated mask.
        """
        mask_2d = torch.ones((bsz, kv_len), device=torch_device, dtype=torch.long)

        if additional_mask is not None:
            for bsz_idx, seq_idx in additional_mask:
                mask_2d[bsz_idx, seq_idx] = 0

        mask_4d = mask_converter.to_4d(mask_2d, query_length=q_len, key_value_length=kv_len)

        self.assertEqual(tuple(mask_4d.shape), (bsz, 1, q_len, kv_len))

        has_disabled_positions = 0 in mask_2d
        num_masked = (mask_4d != 0).sum().cpu().item()
        context = mask_converter.sliding_window

        if mask_converter.is_causal:
            expected = self._expected_num_causal_masked(mask_converter, bsz, q_len, kv_len)
            if has_disabled_positions:
                # at least the causal (and sliding-window) mask + maybe more
                self.assertGreaterEqual(num_masked, expected)
                self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
            else:
                self.assertEqual(num_masked, expected)
        elif context is None:
            if has_disabled_positions:
                self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
            else:
                self.assertEqual(num_masked, 0)
        # A non-causal converter with a sliding window is not checked here.

    def check_to_causal(self, mask_converter, q_len, kv_len, bsz=3):
        """Validate the mask returned by ``mask_converter.to_causal_4d``.

        Args:
            mask_converter: Converter under test; ``is_causal`` and ``sliding_window`` are read.
            q_len: Query length requested from the converter.
            kv_len: Key/value length requested from the converter.
            bsz: Batch size requested from the converter.
        """
        mask_4d = mask_converter.to_causal_4d(bsz, query_length=q_len, key_value_length=kv_len, device=torch_device)

        if q_len == 1 and mask_converter.sliding_window is None:
            # no causal mask if q_len is 1
            self.assertIsNone(mask_4d)
            return

        num_masked = (mask_4d != 0).sum().cpu().item()

        if mask_converter.is_causal:
            expected = self._expected_num_causal_masked(mask_converter, bsz, q_len, kv_len)
            self.assertEqual(num_masked, expected)
        elif mask_converter.sliding_window is None:
            self.assertEqual(num_masked, 0)

    def compute_num_context_mask(self, kv_len, context, q_len):
        """Return the number of entries (per batch item) added to the mask by the sliding window.

        Args:
            kv_len: Key/value length of the mask.
            context: Sliding-window size.
            q_len: Query length of the mask.

        Returns:
            ``T(c) - T(max(c - q_len, 0))`` with ``c = max(kv_len - context, 0)`` and
            ``T(n) = n * (n + 1) // 2``.
        """
        # Clamp at zero: a window larger than the key/value length masks nothing extra.
        # Without the clamp a negative length would yield a positive triangle count.
        c_mask_len = max(kv_len - context, 0)
        num_mask_triangle = c_mask_len * (c_mask_len + 1) // 2
        # With fewer than c_mask_len query rows, part of that triangle is cut off.
        cut_mask_len = max(c_mask_len - q_len, 0)
        num_cut_mask = cut_mask_len * (cut_mask_len + 1) // 2
        return num_mask_triangle - num_cut_mask

    def test_2d_to_4d_causal(self):
        """``to_4d`` with a causal converter, with and without disabled positions."""
        mask_converter = AttentionMaskConverter(is_causal=True)

        # auto-regressive use case
        self.check_to_4d(mask_converter, q_len=1, kv_len=7)
        # special auto-regressive case
        self.check_to_4d(mask_converter, q_len=3, kv_len=7)
        # non auto-regressive case
        self.check_to_4d(mask_converter, q_len=7, kv_len=7)

        # same with extra attention masks
        self.check_to_4d(mask_converter, q_len=1, kv_len=7, additional_mask=self._EXTRA_MASKED_POSITIONS)
        self.check_to_4d(mask_converter, q_len=3, kv_len=7, additional_mask=self._EXTRA_MASKED_POSITIONS)
        self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=self._EXTRA_MASKED_POSITIONS)

    def test_2d_to_4d(self):
        """``to_4d`` with a non-causal converter, with and without disabled positions."""
        mask_converter = AttentionMaskConverter(is_causal=False)

        # non auto-regressive case
        self.check_to_4d(mask_converter, q_len=7, kv_len=7)

        # same with extra attention masks
        self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=self._EXTRA_MASKED_POSITIONS)

    def test_2d_to_4d_causal_sliding(self):
        """``to_4d`` with a causal converter using a sliding window of 5."""
        mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=5)

        # auto-regressive use case
        self.check_to_4d(mask_converter, q_len=1, kv_len=7)
        # special auto-regressive case
        self.check_to_4d(mask_converter, q_len=3, kv_len=7)
        # non auto-regressive case
        self.check_to_4d(mask_converter, q_len=7, kv_len=7)

        # same with extra attention masks
        self.check_to_4d(mask_converter, q_len=1, kv_len=7, additional_mask=self._EXTRA_MASKED_POSITIONS)
        self.check_to_4d(mask_converter, q_len=3, kv_len=7, additional_mask=self._EXTRA_MASKED_POSITIONS)
        self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=self._EXTRA_MASKED_POSITIONS)

    def test_causal_mask(self):
        """``to_causal_4d`` with a causal converter and no sliding window."""
        mask_converter = AttentionMaskConverter(is_causal=True)

        # auto-regressive use case
        self.check_to_causal(mask_converter, q_len=1, kv_len=7)
        # special auto-regressive case
        self.check_to_causal(mask_converter, q_len=3, kv_len=7)
        # non auto-regressive case
        self.check_to_causal(mask_converter, q_len=7, kv_len=7)

    def test_causal_mask_sliding(self):
        """``to_causal_4d`` with a causal converter using a sliding window of 3."""
        mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=3)

        # auto-regressive use case
        self.check_to_causal(mask_converter, q_len=1, kv_len=7)
        # special auto-regressive case
        self.check_to_causal(mask_converter, q_len=3, kv_len=7)
        # non auto-regressive case
        self.check_to_causal(mask_converter, q_len=7, kv_len=7)
|
||||
|
||||
Reference in New Issue
Block a user