[GPTNeo] create local attention mask ones (#11335)

* create local attention mask ones

* remove old method, address patricks comment
This commit is contained in:
Suraj Patil
2021-04-20 18:37:44 +05:30
committed by GitHub
parent f464f10a2c
commit cfd2eaa8cf
2 changed files with 83 additions and 65 deletions

View File

@@ -36,7 +36,7 @@ if is_torch_available():
GPTNeoForCausalLM,
GPTNeoModel,
)
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoAttentionMixin, GPTNeoLocalSelfAttention
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoAttentionMixin
class GPTNeoModelTester:
@@ -497,12 +497,14 @@ class GPTNeoLocalAttentionTest(unittest.TestCase):
def test_create_attention_mask(self):
config = GPTNeoConfig.from_pretrained("valhalla/gpt-neo-random-tiny")
layer = GPTNeoLocalSelfAttention(config)
window_size = config.window_size
batch_size, seq_length = 8, 1
block_length, num_blocks = GPTNeoAttentionMixin._get_block_length_and_num_blocks(seq_length, window_size)
causal_mask = layer._create_attention_mask(batch_size, seq_length, num_blocks, block_length, torch_device)
# causal_mask = layer._create_attention_mask(batch_size, seq_length, num_blocks, block_length, torch_device)
causal_mask = GPTNeoAttentionMixin.create_local_attention_mask(
batch_size, seq_length, config.window_size, torch_device
)
# check shapes
expected_shape = [batch_size, num_blocks, 1, block_length, window_size + block_length]
self.assertListEqual(list(causal_mask.shape), expected_shape)
@@ -516,8 +518,11 @@ class GPTNeoLocalAttentionTest(unittest.TestCase):
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long, device=torch_device)
attention_mask[:, -3:] = 0 # don't attend last 3 tokens
causal_mask = layer._create_attention_mask(
batch_size, seq_length, num_blocks, block_length, torch_device, attention_mask
# causal_mask = layer._create_attention_mask(
# batch_size, seq_length, num_blocks, block_length, torch_device, attention_mask
# )
causal_mask = GPTNeoAttentionMixin.create_local_attention_mask(
batch_size, seq_length, config.window_size, torch_device, attention_mask
)
# last 3 tokens will be in the last block and shoul have 0s in causal_mask
self.assertTrue(torch.all(causal_mask[:, -1, :, :, -3:] == 0))
@@ -539,8 +544,11 @@ class GPTNeoLocalAttentionTest(unittest.TestCase):
mask_tokens = 3
attention_mask = torch.ones(batch_size, seq_length, device=torch_device, dtype=torch.long)
attention_mask[:, -mask_tokens:] = 0 # dont atten last mask_tokens
local_causal_mask = GPTNeoAttentionMixin.create_local_attention_mask(
batch_size, seq_length, model.config.window_size, torch_device, attention_mask
)
_, attn_probs = layer(hidden_states, attention_mask=attention_mask, output_attentions=True)
_, attn_probs = layer(hidden_states, attention_mask=local_causal_mask, output_attentions=True)
# the last 3 tokens will be in the last block, and should have 0 attn_probs
self.assertTrue(torch.all(attn_probs[:, -1, :, -mask_tokens:, -mask_tokens:] == 0))