[GPTNeo] create local attention mask ones (#11335)
* create local attention mask ones * remove old method, address patricks comment
This commit is contained in:
@@ -36,7 +36,7 @@ if is_torch_available():
|
||||
GPTNeoForCausalLM,
|
||||
GPTNeoModel,
|
||||
)
|
||||
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoAttentionMixin, GPTNeoLocalSelfAttention
|
||||
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoAttentionMixin
|
||||
|
||||
|
||||
class GPTNeoModelTester:
|
||||
@@ -497,12 +497,14 @@ class GPTNeoLocalAttentionTest(unittest.TestCase):
|
||||
|
||||
def test_create_attention_mask(self):
|
||||
config = GPTNeoConfig.from_pretrained("valhalla/gpt-neo-random-tiny")
|
||||
layer = GPTNeoLocalSelfAttention(config)
|
||||
window_size = config.window_size
|
||||
batch_size, seq_length = 8, 1
|
||||
block_length, num_blocks = GPTNeoAttentionMixin._get_block_length_and_num_blocks(seq_length, window_size)
|
||||
|
||||
causal_mask = layer._create_attention_mask(batch_size, seq_length, num_blocks, block_length, torch_device)
|
||||
# causal_mask = layer._create_attention_mask(batch_size, seq_length, num_blocks, block_length, torch_device)
|
||||
causal_mask = GPTNeoAttentionMixin.create_local_attention_mask(
|
||||
batch_size, seq_length, config.window_size, torch_device
|
||||
)
|
||||
# check shapes
|
||||
expected_shape = [batch_size, num_blocks, 1, block_length, window_size + block_length]
|
||||
self.assertListEqual(list(causal_mask.shape), expected_shape)
|
||||
@@ -516,8 +518,11 @@ class GPTNeoLocalAttentionTest(unittest.TestCase):
|
||||
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long, device=torch_device)
|
||||
attention_mask[:, -3:] = 0 # don't attend last 3 tokens
|
||||
|
||||
causal_mask = layer._create_attention_mask(
|
||||
batch_size, seq_length, num_blocks, block_length, torch_device, attention_mask
|
||||
# causal_mask = layer._create_attention_mask(
|
||||
# batch_size, seq_length, num_blocks, block_length, torch_device, attention_mask
|
||||
# )
|
||||
causal_mask = GPTNeoAttentionMixin.create_local_attention_mask(
|
||||
batch_size, seq_length, config.window_size, torch_device, attention_mask
|
||||
)
|
||||
# last 3 tokens will be in the last block and shoul have 0s in causal_mask
|
||||
self.assertTrue(torch.all(causal_mask[:, -1, :, :, -3:] == 0))
|
||||
@@ -539,8 +544,11 @@ class GPTNeoLocalAttentionTest(unittest.TestCase):
|
||||
mask_tokens = 3
|
||||
attention_mask = torch.ones(batch_size, seq_length, device=torch_device, dtype=torch.long)
|
||||
attention_mask[:, -mask_tokens:] = 0 # dont atten last mask_tokens
|
||||
local_causal_mask = GPTNeoAttentionMixin.create_local_attention_mask(
|
||||
batch_size, seq_length, model.config.window_size, torch_device, attention_mask
|
||||
)
|
||||
|
||||
_, attn_probs = layer(hidden_states, attention_mask=attention_mask, output_attentions=True)
|
||||
_, attn_probs = layer(hidden_states, attention_mask=local_causal_mask, output_attentions=True)
|
||||
|
||||
# the last 3 tokens will be in the last block, and should have 0 attn_probs
|
||||
self.assertTrue(torch.all(attn_probs[:, -1, :, -mask_tokens:, -mask_tokens:] == 0))
|
||||
|
||||
Reference in New Issue
Block a user