Llama: fix custom 4D masks, v2 (#30348)

* 4d mask fixes

* Update custom 4D mask logic

* test moved to mixin

* extra tests 4d mask

* upd 4d mask and StaticCache handling

* added Mask4DTestHard to mistral tests

* post-rebase fixes

* test fixes for StaticCache

* make fix-copies

* upd 1 after #30476

* fix common tests

* rm elif attention_mask.dim() == 4:

* tests combined, fixed, mixtral supported

* bigbird style chg reverted

* rm if attention_mask.dim() == 2

* modeling_llama formatting chg

---------

Co-authored-by: Joao Gante <joao@huggingface.co>
This commit is contained in:
Poedator
2024-05-13 13:46:06 +02:00
committed by GitHub
parent 453893ed15
commit a0779b9e19
11 changed files with 541 additions and 366 deletions

View File

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import gc
import glob
import json
import os
@@ -53,7 +52,6 @@ from transformers.testing_utils import (
require_tf,
require_torch,
require_torch_accelerator,
require_torch_gpu,
require_torch_multi_accelerator,
require_usr_bin_time,
slow,
@@ -2107,229 +2105,6 @@ class TestAttentionImplementation(unittest.TestCase):
self.assertTrue("PyTorch SDPA requirements in Transformers are not met" in str(cm.exception))
@require_torch_gpu
class Mask4DTestBase(unittest.TestCase):
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
def get_test_data(self):
texts = ["the cat sat", "the cat had", "the cat is"]
encoded = [self.tokenizer.encode(t) for t in texts]
input_0 = torch.tensor(encoded, device=torch_device)
# tensor([[ 1, 278, 6635, 3290],
# [ 1, 278, 6635, 750],
# [ 1, 278, 6635, 338]], device='cuda:0')
position_ids_0 = torch.tensor([[0, 1, 2, 3]] * 3, device=torch_device, dtype=torch.int64)
# Combining common prefix with the unique ending tokens:
input_1 = torch.cat([input_0[0][:-1], input_0[:, -1]]).unsqueeze(0)
# tensor([[ 1, 278, 6635, 3290, 750, 338]], device='cuda:0')
# Creating a 4D mask where each of the last 3 tokens do not attend to each other.
mask_1 = torch.tensor(
[
[
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 0, 0],
[1, 1, 1, 0, 1, 0],
[1, 1, 1, 0, 0, 1],
]
]
],
device="cuda:0",
dtype=torch.int64,
)
# Creating a position_ids tensor. note the repeating figures in the end.
position_ids_1 = torch.tensor([[0, 1, 2, 3, 3, 3]], device=torch_device, dtype=torch.int64)
return input_0, position_ids_0, input_1, mask_1, position_ids_1
@require_torch_gpu
class Mask4DTestFP32(Mask4DTestBase):
def setUp(self):
model_name = "JackFram/llama-68m" # small Llama-like model from FlexFlow
self.model_dtype = torch.float32
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
def test_attention(self):
"""comparing outputs of attention layer"""
# Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()
causal_mask_1 = (1 - mask_1).to(self.model_dtype) * torch.finfo(self.model_dtype).min
hid_0 = self.model.model.embed_tokens(input_0)
outs_0 = self.model.model.layers[0].self_attn.forward(hid_0, position_ids=position_ids_0)[0]
# outs_0.shape == torch.Size([3, 4, 768])
hid_1 = self.model.model.embed_tokens(input_1)
outs_1 = self.model.model.layers[0].self_attn.forward(
hid_1, attention_mask=causal_mask_1, position_ids=position_ids_1
)[0]
# outs_1.shape == torch.Size([1, 6, 768])
outs_0_last_tokens = outs_0[:, -1, :] # last tokens in each batch line
outs_1_last_tokens = outs_1[0, -3:, :] # last three tokens
torch.testing.assert_close(outs_0_last_tokens, outs_1_last_tokens)
def test_causal_model_logits(self):
"""comparing logits outputs of whole inner model"""
# Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()
logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits
logits_0_last_tokens = logits_0[:, -1, :] # last tokens in each batch line
logits_1_last_tokens = logits_1[0, -3:, :] # last three tokens
torch.testing.assert_close(logits_0_last_tokens, logits_1_last_tokens)
@require_torch_gpu
class Mask4DTestFP16(Mask4DTestBase):
test_attention = Mask4DTestFP32.test_attention
def setUp(self):
model_name = "JackFram/llama-68m" # small Llama-like model from FlexFlow
self.model_dtype = torch.float16
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
def test_causal_model_logits(self):
"""comparing logits outputs of whole inner model"""
# Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()
logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits
logits_0_last_tokens = logits_0[:, -1, :] # last tokens in each batch line
logits_1_last_tokens = logits_1[0, -3:, :] # last three tokens
indices_0 = logits_0_last_tokens.sort(descending=True).indices
indices_1 = logits_1_last_tokens.sort(descending=True).indices
# checking logits, but note relaxed tolerances for FP16
torch.testing.assert_close(logits_0_last_tokens, logits_1_last_tokens, atol=0.02, rtol=0.001)
# checking tokens order for the top tokens
for token_ids_0, token_ids_1 in zip(indices_0, indices_1):
self.assertTrue(torch.equal(token_ids_0[:128], token_ids_1[:128]))
@slow
@require_torch_gpu
class Mask4DTestHard(unittest.TestCase):
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
def setUp(self):
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
self.model_dtype = torch.float32
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
def get_test_data(self):
template = "my favorite {}"
items = ("pet is a", "artist plays a", "name is L") # same number of tokens in each item
batch_0 = [template.format(x) for x in items] # 3 separate lines
batch_1 = template.format(" ".join(items)) # 1 line with options concatenated
input_0 = self.tokenizer(batch_0, return_tensors="pt").input_ids.to(torch_device)
input_1 = self.tokenizer(batch_1, return_tensors="pt").input_ids.to(torch_device)
mask_1 = torch.tensor(
[
[
[
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1],
]
]
],
device=torch_device,
dtype=torch.int64,
)
position_ids_0 = torch.arange(input_0.shape[1]).tile(input_0.shape[0], 1).to(torch_device)
# equivalent: position_ids_1 = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device)
position_ids_1 = (mask_1.sum(dim=-1) - 1).reshape(1, -1) # same but nicer
return input_0, position_ids_0, input_1, mask_1, position_ids_1
def test_stacked_causal_mask(self):
# Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()
# regular batch
logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
logits_0_last = logits_0[:, -1, :] # last tokens in each batch line
decoded_0 = [self.tokenizer.decode(t) for t in logits_0_last.argmax(dim=-1)]
# single forward run with 4D custom mask
logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits
logits_1_last = logits_1[0, torch.where(position_ids_1 == position_ids_1.max())[1], :] # last three tokens
decoded_1 = [self.tokenizer.decode(t) for t in logits_1_last.argmax(dim=-1)]
self.assertEqual(decoded_0, decoded_1)
def test_partial_stacked_causal_mask(self):
# Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention
# masks
# Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()
# regular batch
logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
logits_0_last = logits_0[:, -1, :] # last tokens in each batch line
decoded_0 = [self.tokenizer.decode(t) for t in logits_0_last.argmax(dim=-1)]
# 2 forward runs with custom 4D masks
part_a = 3 # split point
input_1a = input_1[:, :part_a]
position_ids_1a = position_ids_1[:, :part_a]
mask_1a = mask_1[:, :, :part_a, :part_a]
outs_1a = self.model.forward(input_1a, attention_mask=mask_1a.bool(), position_ids=position_ids_1a)
past_key_values_a = outs_1a["past_key_values"]
input_1b = input_1[:, part_a:]
position_ids_1b = position_ids_1[:, part_a:]
mask_1b = mask_1[:, :, part_a:, :]
outs_1b = self.model.forward(
input_1b, attention_mask=mask_1b.bool(), position_ids=position_ids_1b, past_key_values=past_key_values_a
)
decoded_1b = [
self.tokenizer.decode(t)
for t in outs_1b.logits.argmax(-1)[0, torch.where(position_ids_1 == position_ids_1.max())[1] - part_a]
]
self.assertEqual(decoded_0, decoded_1b)
@require_torch
class TestTensorSharing(TestCasePlus):
def test_disjoint(self):