[GPTNeo] create local attention mask ones (#11335)
* create local attention mask ones * remove old method, address patricks comment
This commit is contained in:
@@ -192,6 +192,57 @@ class GPTNeoAttentionMixin:
|
|||||||
padded_tensor = padded_tensor.transpose(-2, -1)
|
padded_tensor = padded_tensor.transpose(-2, -1)
|
||||||
return padded_tensor
|
return padded_tensor
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _split_seq_length_dim_to(tensors, dim_factor_1, dim_factor_2):
|
||||||
|
"""
|
||||||
|
Splits sequence length dim of tensors into `dim_factor_1` and `dim_factor_2` dims
|
||||||
|
"""
|
||||||
|
batch_size = tensors.shape[0]
|
||||||
|
split_dim_shape = (batch_size, dim_factor_1, dim_factor_2)
|
||||||
|
|
||||||
|
if len(tensors.shape) == 3:
|
||||||
|
return torch.reshape(tensors, split_dim_shape + (-1,))
|
||||||
|
elif len(tensors.shape) == 2:
|
||||||
|
return torch.reshape(tensors, split_dim_shape)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Input vector rank should be one of [2, 3], but is: {len(tensors.shape)}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_local_attention_mask(batch_size, seq_length, window_size, device, attention_mask=None):
|
||||||
|
block_length, num_blocks = GPTNeoAttentionMixin._get_block_length_and_num_blocks(seq_length, window_size)
|
||||||
|
indices = torch.arange(seq_length, dtype=torch.long, device=device).repeat(batch_size, 1)
|
||||||
|
|
||||||
|
query_indices = GPTNeoAttentionMixin._split_seq_length_dim_to(indices, num_blocks, block_length)
|
||||||
|
key_indices = GPTNeoAttentionMixin._look_back(indices, block_length, window_size, is_key_value=False)
|
||||||
|
|
||||||
|
# create mask tensor such that each block contains a causal_mask for that block
|
||||||
|
causal_mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2))
|
||||||
|
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
# A block can also be padded becuase of the _look_back operation
|
||||||
|
# look back into the attention_block such that it will also get padded the same way
|
||||||
|
# and have 0s in the padded position
|
||||||
|
attention_mask = GPTNeoAttentionMixin._look_back(attention_mask, block_length, window_size, is_key_value=False)
|
||||||
|
attention_mask = attention_mask.unsqueeze(-2) # Add an extra dimension to account for hidden_dim
|
||||||
|
|
||||||
|
# Multiply the causal_mask with attention_mask so the padded positions (by _look_back operation)
|
||||||
|
# will contain 0s.
|
||||||
|
# This also makes sure that other positions ignored by the attention_mask will also be ignored
|
||||||
|
# in the causal_mask.
|
||||||
|
causal_mask = causal_mask * attention_mask
|
||||||
|
|
||||||
|
# In GPT Neo's local attention each window can attend to at most window_size tokens
|
||||||
|
# rest of the tokens should be ignored.
|
||||||
|
relative_position = key_indices.unsqueeze(-2) - query_indices.unsqueeze(-1)
|
||||||
|
visible = torch.gt(relative_position, -window_size)
|
||||||
|
|
||||||
|
causal_mask = causal_mask * visible
|
||||||
|
causal_mask = causal_mask.unsqueeze(-3).bool() # Add an extra dimension to account for num_heads
|
||||||
|
|
||||||
|
return causal_mask
|
||||||
|
|
||||||
def _split_heads(self, tensor, num_heads, attn_head_size):
|
def _split_heads(self, tensor, num_heads, attn_head_size):
|
||||||
"""
|
"""
|
||||||
Splits hidden_size dim into attn_head_size and num_heads
|
Splits hidden_size dim into attn_head_size and num_heads
|
||||||
@@ -218,20 +269,6 @@ class GPTNeoAttentionMixin:
|
|||||||
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
|
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
|
||||||
return tensor.view(new_shape)
|
return tensor.view(new_shape)
|
||||||
|
|
||||||
def _split_seq_length_dim_to(self, tensors, dim_factor_1, dim_factor_2, hidden_size):
|
|
||||||
"""
|
|
||||||
Splits sequence length dim of tensors into `dim_factor_1` and `dim_factor_2` dims
|
|
||||||
"""
|
|
||||||
batch_size = tensors.shape[0]
|
|
||||||
split_dim_shape = (batch_size, dim_factor_1, dim_factor_2)
|
|
||||||
|
|
||||||
if len(tensors.shape) == 3:
|
|
||||||
return torch.reshape(tensors, split_dim_shape + (hidden_size,))
|
|
||||||
elif len(tensors.shape) == 2:
|
|
||||||
return torch.reshape(tensors, split_dim_shape)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Input vector rank should be one of [2, 3], but is: {len(tensors.shape)}")
|
|
||||||
|
|
||||||
def _attn(self, query, key, value, causal_mask, masked_bias, attn_dropout, attention_mask=None, head_mask=None):
|
def _attn(self, query, key, value, causal_mask, masked_bias, attn_dropout, attention_mask=None, head_mask=None):
|
||||||
# Keep the attention weights computation in fp32 to avoid overflow issues
|
# Keep the attention weights computation in fp32 to avoid overflow issues
|
||||||
query = query.to(torch.float32)
|
query = query.to(torch.float32)
|
||||||
@@ -289,8 +326,8 @@ class GPTNeoSelfAttention(nn.Module, GPTNeoAttentionMixin):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
layer_past=None,
|
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
layer_past=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
use_cache=False,
|
use_cache=False,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
@@ -357,45 +394,11 @@ class GPTNeoLocalSelfAttention(nn.Module, GPTNeoAttentionMixin):
|
|||||||
|
|
||||||
self.window_size = config.window_size
|
self.window_size = config.window_size
|
||||||
|
|
||||||
def _create_attention_mask(self, batch_size, seq_length, num_blocks, block_length, device, attention_mask=None):
|
|
||||||
indices = torch.arange(seq_length, dtype=torch.long, device=device).repeat(batch_size, 1)
|
|
||||||
|
|
||||||
query_indices = self._split_seq_length_dim_to(indices, num_blocks, block_length, self.embed_dim)
|
|
||||||
key_indices = self._look_back(indices, block_length, self.window_size, is_key_value=False)
|
|
||||||
|
|
||||||
# create mask tensor such that each block contains a causal_mask for that block
|
|
||||||
causal_mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2))
|
|
||||||
|
|
||||||
if attention_mask is None:
|
|
||||||
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long, device=device)
|
|
||||||
|
|
||||||
# A block can also be padded becuase of the _look_back operation
|
|
||||||
# look back into the attention_block such that it will also get padded the same way
|
|
||||||
# and have 0s in the padded position
|
|
||||||
attention_mask = self._look_back(attention_mask, block_length, self.window_size, is_key_value=False)
|
|
||||||
attention_mask = attention_mask.unsqueeze(-2) # Add an extra dimension to account for hidden_dim
|
|
||||||
|
|
||||||
# Multiply the causal_mask with attention_mask so the padded positions (by _look_back operation)
|
|
||||||
# will contain 0s.
|
|
||||||
# This also makes sure that other positions ignored by the attention_mask will also be ignored
|
|
||||||
# in the causal_mask.
|
|
||||||
causal_mask = causal_mask * attention_mask
|
|
||||||
|
|
||||||
# In GPT Neo's local attention each window can attend to at most window_size tokens
|
|
||||||
# rest of the tokens should be ignored.
|
|
||||||
relative_position = key_indices.unsqueeze(-2) - query_indices.unsqueeze(-1)
|
|
||||||
visible = torch.gt(relative_position, -self.window_size)
|
|
||||||
|
|
||||||
causal_mask = causal_mask * visible
|
|
||||||
causal_mask = causal_mask.unsqueeze(-3).bool() # Add an extra dimension to account for num_heads
|
|
||||||
|
|
||||||
return causal_mask
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
layer_past=None,
|
layer_past=None,
|
||||||
attention_mask=None,
|
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
use_cache=False,
|
use_cache=False,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
@@ -421,9 +424,9 @@ class GPTNeoLocalSelfAttention(nn.Module, GPTNeoAttentionMixin):
|
|||||||
# create buckets
|
# create buckets
|
||||||
if layer_past is not None:
|
if layer_past is not None:
|
||||||
# we just need 1 block with block_length 1 when caching is enabled
|
# we just need 1 block with block_length 1 when caching is enabled
|
||||||
query = self._split_seq_length_dim_to(query, 1, 1, self.embed_dim)
|
query = self._split_seq_length_dim_to(query, 1, 1)
|
||||||
else:
|
else:
|
||||||
query = self._split_seq_length_dim_to(query, num_blocks, block_length, self.embed_dim)
|
query = self._split_seq_length_dim_to(query, num_blocks, block_length)
|
||||||
|
|
||||||
key = self._look_back(key, block_length, self.window_size)
|
key = self._look_back(key, block_length, self.window_size)
|
||||||
value = self._look_back(value, block_length, self.window_size)
|
value = self._look_back(value, block_length, self.window_size)
|
||||||
@@ -437,18 +440,16 @@ class GPTNeoLocalSelfAttention(nn.Module, GPTNeoAttentionMixin):
|
|||||||
key = self._split_heads(key, self.num_heads, self.head_dim)
|
key = self._split_heads(key, self.num_heads, self.head_dim)
|
||||||
value = self._split_heads(value, self.num_heads, self.head_dim)
|
value = self._split_heads(value, self.num_heads, self.head_dim)
|
||||||
|
|
||||||
mask = self._create_attention_mask(
|
|
||||||
batch_size, full_seq_length, num_blocks, block_length, hidden_states.device, attention_mask
|
|
||||||
)
|
|
||||||
if layer_past is not None:
|
if layer_past is not None:
|
||||||
mask = mask[:, -1:, :, -1:, :] # only take the mask for the last block
|
# only take the mask for the last block
|
||||||
|
attention_mask = attention_mask[:, -1:, :, -1:, :]
|
||||||
|
|
||||||
# attn
|
# attn
|
||||||
attn_output, attn_weights = self._attn(
|
attn_output, attn_weights = self._attn(
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
causal_mask=mask,
|
causal_mask=attention_mask,
|
||||||
masked_bias=self.masked_bias,
|
masked_bias=self.masked_bias,
|
||||||
attn_dropout=self.attn_dropout,
|
attn_dropout=self.attn_dropout,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
@@ -495,8 +496,8 @@ class GPTNeoAttention(nn.Module):
|
|||||||
):
|
):
|
||||||
outputs = self.attention(
|
outputs = self.attention(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
layer_past=layer_past,
|
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
layer_past=layer_past,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
@@ -767,6 +768,8 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
|
|||||||
past_key_values = tuple([None] * len(self.h))
|
past_key_values = tuple([None] * len(self.h))
|
||||||
else:
|
else:
|
||||||
past_length = past_key_values[0][0].size(-2)
|
past_length = past_key_values[0][0].size(-2)
|
||||||
|
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
||||||
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
||||||
@@ -792,6 +795,13 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
global_attention_mask = None
|
global_attention_mask = None
|
||||||
|
|
||||||
|
# Local causal attention mask
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
full_seq_length = seq_length + past_length
|
||||||
|
local_attention_mask = GPTNeoAttentionMixin.create_local_attention_mask(
|
||||||
|
batch_size, full_seq_length, self.config.window_size, device, attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
# Prepare head mask if needed
|
# Prepare head mask if needed
|
||||||
# 1.0 in head_mask indicate we keep the head
|
# 1.0 in head_mask indicate we keep the head
|
||||||
# attention_probs has shape bsz x num_headss x N x N
|
# attention_probs has shape bsz x num_headss x N x N
|
||||||
@@ -816,7 +826,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
|
|||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||||
attn_type = self.config.attention_layers[i]
|
attn_type = self.config.attention_layers[i]
|
||||||
attn_mask = global_attention_mask if attn_type == "global" else attention_mask
|
attn_mask = global_attention_mask if attn_type == "global" else local_attention_mask
|
||||||
|
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ if is_torch_available():
|
|||||||
GPTNeoForCausalLM,
|
GPTNeoForCausalLM,
|
||||||
GPTNeoModel,
|
GPTNeoModel,
|
||||||
)
|
)
|
||||||
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoAttentionMixin, GPTNeoLocalSelfAttention
|
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoAttentionMixin
|
||||||
|
|
||||||
|
|
||||||
class GPTNeoModelTester:
|
class GPTNeoModelTester:
|
||||||
@@ -497,12 +497,14 @@ class GPTNeoLocalAttentionTest(unittest.TestCase):
|
|||||||
|
|
||||||
def test_create_attention_mask(self):
|
def test_create_attention_mask(self):
|
||||||
config = GPTNeoConfig.from_pretrained("valhalla/gpt-neo-random-tiny")
|
config = GPTNeoConfig.from_pretrained("valhalla/gpt-neo-random-tiny")
|
||||||
layer = GPTNeoLocalSelfAttention(config)
|
|
||||||
window_size = config.window_size
|
window_size = config.window_size
|
||||||
batch_size, seq_length = 8, 1
|
batch_size, seq_length = 8, 1
|
||||||
block_length, num_blocks = GPTNeoAttentionMixin._get_block_length_and_num_blocks(seq_length, window_size)
|
block_length, num_blocks = GPTNeoAttentionMixin._get_block_length_and_num_blocks(seq_length, window_size)
|
||||||
|
|
||||||
causal_mask = layer._create_attention_mask(batch_size, seq_length, num_blocks, block_length, torch_device)
|
# causal_mask = layer._create_attention_mask(batch_size, seq_length, num_blocks, block_length, torch_device)
|
||||||
|
causal_mask = GPTNeoAttentionMixin.create_local_attention_mask(
|
||||||
|
batch_size, seq_length, config.window_size, torch_device
|
||||||
|
)
|
||||||
# check shapes
|
# check shapes
|
||||||
expected_shape = [batch_size, num_blocks, 1, block_length, window_size + block_length]
|
expected_shape = [batch_size, num_blocks, 1, block_length, window_size + block_length]
|
||||||
self.assertListEqual(list(causal_mask.shape), expected_shape)
|
self.assertListEqual(list(causal_mask.shape), expected_shape)
|
||||||
@@ -516,8 +518,11 @@ class GPTNeoLocalAttentionTest(unittest.TestCase):
|
|||||||
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long, device=torch_device)
|
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long, device=torch_device)
|
||||||
attention_mask[:, -3:] = 0 # don't attend last 3 tokens
|
attention_mask[:, -3:] = 0 # don't attend last 3 tokens
|
||||||
|
|
||||||
causal_mask = layer._create_attention_mask(
|
# causal_mask = layer._create_attention_mask(
|
||||||
batch_size, seq_length, num_blocks, block_length, torch_device, attention_mask
|
# batch_size, seq_length, num_blocks, block_length, torch_device, attention_mask
|
||||||
|
# )
|
||||||
|
causal_mask = GPTNeoAttentionMixin.create_local_attention_mask(
|
||||||
|
batch_size, seq_length, config.window_size, torch_device, attention_mask
|
||||||
)
|
)
|
||||||
# last 3 tokens will be in the last block and shoul have 0s in causal_mask
|
# last 3 tokens will be in the last block and shoul have 0s in causal_mask
|
||||||
self.assertTrue(torch.all(causal_mask[:, -1, :, :, -3:] == 0))
|
self.assertTrue(torch.all(causal_mask[:, -1, :, :, -3:] == 0))
|
||||||
@@ -539,8 +544,11 @@ class GPTNeoLocalAttentionTest(unittest.TestCase):
|
|||||||
mask_tokens = 3
|
mask_tokens = 3
|
||||||
attention_mask = torch.ones(batch_size, seq_length, device=torch_device, dtype=torch.long)
|
attention_mask = torch.ones(batch_size, seq_length, device=torch_device, dtype=torch.long)
|
||||||
attention_mask[:, -mask_tokens:] = 0 # dont atten last mask_tokens
|
attention_mask[:, -mask_tokens:] = 0 # dont atten last mask_tokens
|
||||||
|
local_causal_mask = GPTNeoAttentionMixin.create_local_attention_mask(
|
||||||
|
batch_size, seq_length, model.config.window_size, torch_device, attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
_, attn_probs = layer(hidden_states, attention_mask=attention_mask, output_attentions=True)
|
_, attn_probs = layer(hidden_states, attention_mask=local_causal_mask, output_attentions=True)
|
||||||
|
|
||||||
# the last 3 tokens will be in the last block, and should have 0 attn_probs
|
# the last 3 tokens will be in the last block, and should have 0 attn_probs
|
||||||
self.assertTrue(torch.all(attn_probs[:, -1, :, -mask_tokens:, -mask_tokens:] == 0))
|
self.assertTrue(torch.all(attn_probs[:, -1, :, -mask_tokens:, -mask_tokens:] == 0))
|
||||||
|
|||||||
Reference in New Issue
Block a user