4D attention_mask support (#27539)
* edits to _prepare_4d_causal_attention_mask()
* initial tests for 4d mask
* attention_mask_for_sdpa support
* added test for inner model hidden
* added autotest decorators
* test mask dtype to torch.int64
* torch.testing.assert_close

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* torch_device and @torch_gpu in tests
* upd tests
* +torch decorators
* torch decorators fixed
* more decorators!
* even more decorators
* fewer decorators

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -302,10 +302,22 @@ def _prepare_4d_causal_attention_mask(
|
|||||||
key_value_length = input_shape[-1] + past_key_values_length
|
key_value_length = input_shape[-1] + past_key_values_length
|
||||||
|
|
||||||
# 4d mask is passed through the layers
|
# 4d mask is passed through the layers
|
||||||
if attention_mask is not None:
|
if attention_mask is not None and len(attention_mask.shape) == 2:
|
||||||
attention_mask = attn_mask_converter.to_4d(
|
attention_mask = attn_mask_converter.to_4d(
|
||||||
attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
|
attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
|
||||||
)
|
)
|
||||||
|
elif attention_mask is not None and len(attention_mask.shape) == 4:
|
||||||
|
expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
|
||||||
|
if tuple(attention_mask.shape) != expected_shape:
|
||||||
|
raise ValueError(
|
||||||
|
f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# if the 4D mask has correct shape - invert it and fill with negative infinity
|
||||||
|
inverted_mask = 1.0 - attention_mask
|
||||||
|
attention_mask = inverted_mask.masked_fill(
|
||||||
|
inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
attention_mask = attn_mask_converter.to_causal_4d(
|
attention_mask = attn_mask_converter.to_causal_4d(
|
||||||
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
||||||
@@ -340,7 +352,22 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
|
|||||||
is_tracing = torch.jit.is_tracing()
|
is_tracing = torch.jit.is_tracing()
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
if torch.all(attention_mask == 1):
|
# 4d mask is passed through
|
||||||
|
if len(attention_mask.shape) == 4:
|
||||||
|
expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
|
||||||
|
if tuple(attention_mask.shape) != expected_shape:
|
||||||
|
raise ValueError(
|
||||||
|
f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# if the 4D mask has correct shape - invert it and fill with negative infinity
|
||||||
|
inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
|
||||||
|
attention_mask = inverted_mask.masked_fill(
|
||||||
|
inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
|
||||||
|
)
|
||||||
|
return attention_mask
|
||||||
|
|
||||||
|
elif torch.all(attention_mask == 1):
|
||||||
if is_tracing:
|
if is_tracing:
|
||||||
pass
|
pass
|
||||||
elif query_length == 1:
|
elif query_length == 1:
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import copy
|
import copy
|
||||||
|
import gc
|
||||||
import glob
|
import glob
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
@@ -49,6 +50,7 @@ from transformers.testing_utils import (
|
|||||||
require_tf,
|
require_tf,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torch_accelerator,
|
require_torch_accelerator,
|
||||||
|
require_torch_gpu,
|
||||||
require_torch_multi_accelerator,
|
require_torch_multi_accelerator,
|
||||||
require_usr_bin_time,
|
require_usr_bin_time,
|
||||||
slow,
|
slow,
|
||||||
@@ -1875,3 +1877,134 @@ class TestAttentionImplementation(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue("PyTorch SDPA requirements in Transformers are not met" in str(cm.exception))
|
self.assertTrue("PyTorch SDPA requirements in Transformers are not met" in str(cm.exception))
|
||||||
|
|
||||||
|
|
||||||
|
@slow
@require_torch_gpu
class Mask4DTestBase(unittest.TestCase):
    """Shared fixture helpers for the 4D attention-mask tests.

    This base class reads ``self.tokenizer`` in ``get_test_data``; it does not
    set it, so subclasses must assign it (typically in ``setUp``).
    """

    def tearDown(self):
        # Run garbage collection, then release cached CUDA memory after each test.
        gc.collect()
        torch.cuda.empty_cache()

    def get_test_data(self):
        """Build paired inputs: three prompts as a batch, and the same prompts packed into one row.

        Returns:
            A tuple ``(input_0, input_1, mask_1, position_ids_1)``:

            - ``input_0``: the three encoded prompts stacked as separate batch rows.
            - ``input_1``: a single row made of the first prompt without its last token,
              followed by the last token of every prompt.
            - ``mask_1``: an int64 mask of shape ``(1, 1, 6, 6)`` in which each of the last
              three tokens attends to the shared prefix and itself, but not to the others.
            - ``position_ids_1``: int64 position ids where the last three tokens share
              the same position.
        """
        texts = ["the cat sat", "the cat had", "the cat is"]
        encoded = [self.tokenizer.encode(t) for t in texts]
        # torch.tensor needs rows of equal length, so the prompts must encode to the same number of tokens.
        input_0 = torch.tensor(encoded, device=torch_device)
        # e.g. tensor([[   1,  278, 6635, 3290],
        #              [   1,  278, 6635,  750],
        #              [   1,  278, 6635,  338]], device='cuda:0')

        # Combining common prefix with the unique ending tokens:
        input_1 = torch.cat([input_0[0][:-1], input_0[:, -1]]).unsqueeze(0)
        # e.g. tensor([[   1,  278, 6635, 3290,  750,  338]], device='cuda:0')

        # Creating a 4D mask where each of the last 3 tokens do not attend to each other.
        # It is placed on `torch_device`, like every other tensor built here, so that all
        # fixtures live on the same device as the model.
        mask_1 = torch.tensor(
            [
                [
                    [
                        [1, 0, 0, 0, 0, 0],
                        [1, 1, 0, 0, 0, 0],
                        [1, 1, 1, 0, 0, 0],
                        [1, 1, 1, 1, 0, 0],
                        [1, 1, 1, 0, 1, 0],
                        [1, 1, 1, 0, 0, 1],
                    ]
                ]
            ],
            device=torch_device,
            dtype=torch.int64,
        )

        # Creating a position_ids tensor. note the repeating figures in the end.
        position_ids_1 = torch.tensor([[0, 1, 2, 3, 3, 3]], device=torch_device, dtype=torch.int64)

        return input_0, input_1, mask_1, position_ids_1
|
||||||
|
|
||||||
|
|
||||||
|
@slow
@require_torch_gpu
class Mask4DTestFP32(Mask4DTestBase):
    """4D attention-mask tests run against a float32 model.

    Each test compares a batched run of three prompts with a single packed run
    that uses a custom 4D mask, looking at the last token of each prompt.
    """

    def setUp(self):
        model_name = "JackFram/llama-68m"  # small Llama-like model from FlexFlow
        model_dtype = torch.float32
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype).to(torch_device)

    def test_attention(self):
        """comparing outputs of attention layer"""
        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()

        hid_0 = self.model.model.embed_tokens(input_0)
        outs_0 = self.model.model.layers[0].self_attn.forward(hid_0)[0]
        # outs_0.shape == torch.Size([3, 4, 768])

        hid_1 = self.model.model.embed_tokens(input_1)
        outs_1 = self.model.model.layers[0].self_attn.forward(
            hid_1, attention_mask=mask_1.bool(), position_ids=position_ids_1
        )[0]
        # outs_1.shape == torch.Size([1, 6, 768])

        outs_0_last_tokens = outs_0[:, -1, :]  # last tokens in each batch line
        outs_1_last_tokens = outs_1[0, -3:, :]  # last three tokens
        # Use a TestCase assertion (not a bare `assert`, which is stripped under `python -O`);
        # torch.allclose is kept so the tolerance is unchanged.
        self.assertTrue(torch.allclose(outs_0_last_tokens, outs_1_last_tokens))

    def test_inner_model(self):
        """comparing hidden outputs of whole inner model"""
        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()

        # Call the inner model (`self.model.model`) rather than the causal-LM wrapper, so this
        # test checks hidden states instead of duplicating `test_causal_model_logits`.
        # NOTE(review): index 0 of the inner model's output is expected to be the final
        # hidden state - confirm against the model's output class.
        hidden_0 = self.model.model.forward(input_0)[0]
        hidden_1 = self.model.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1)[0]

        hidden_0_last_tokens = hidden_0[:, -1, :]  # last tokens in each batch line
        hidden_1_last_tokens = hidden_1[0, -3:, :]  # last three tokens
        torch.testing.assert_close(
            hidden_0_last_tokens,
            hidden_1_last_tokens,
        )

    def test_causal_model_logits(self):
        """comparing logits outputs of whole inner model"""
        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()

        logits_0 = self.model.forward(input_0).logits
        logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits

        logits_0_last_tokens = logits_0[:, -1, :]  # last tokens in each batch line
        logits_1_last_tokens = logits_1[0, -3:, :]  # last three tokens
        torch.testing.assert_close(
            logits_0_last_tokens,
            logits_1_last_tokens,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@slow
@require_torch_gpu
class Mask4DTestFP16(Mask4DTestBase):
    """4D attention-mask tests run against a float16 model."""

    # The attention-layer comparison is borrowed unchanged from the FP32 test class.
    test_attention = Mask4DTestFP32.test_attention

    def setUp(self):
        checkpoint = "JackFram/llama-68m"  # small Llama-like model from FlexFlow
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        half_precision_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.float16)
        self.model = half_precision_model.to(torch_device)

    def test_causal_model_logits(self):
        """Compare causal-LM logits of the batched run with those of the packed 4D-mask run."""
        batched_ids, packed_ids, packed_mask, packed_positions = self.get_test_data()

        batched_logits = self.model.forward(batched_ids).logits
        packed_logits = self.model.forward(
            packed_ids,
            attention_mask=packed_mask.bool(),
            position_ids=packed_positions,
        ).logits

        # Last token of every batch row vs. the three trailing tokens of the packed row.
        batched_last = batched_logits[:, -1, :]
        packed_last = packed_logits[0, -3:, :]

        # Vocabulary ids ordered from highest to lowest logit, one ranking per prompt.
        batched_ranking = batched_last.sort(descending=True).indices
        packed_ranking = packed_last.sort(descending=True).indices

        # checking logits, but note relaxed tolerances for FP16
        torch.testing.assert_close(batched_last, packed_last, atol=0.02, rtol=0.001)

        # checking tokens order for the top tokens
        for batched_order, packed_order in zip(batched_ranking, packed_ranking):
            top_batched = batched_order[:128]
            top_packed = packed_order[:128]
            self.assertTrue(torch.equal(top_batched, top_packed))
|
||||||
|
|||||||
Reference in New Issue
Block a user