[Reformer] Adapt Reformer MaskedLM Attn mask (#5560)
* fix attention mask * fix slow test * refactor attn masks * fix fp16 generate test
This commit is contained in:
committed by
GitHub
parent
3dcb748e31
commit
989ae326b5
@@ -373,7 +373,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
|||||||
# use cached buckets for backprop only
|
# use cached buckets for backprop only
|
||||||
if buckets is None:
|
if buckets is None:
|
||||||
# hash query key vectors into buckets
|
# hash query key vectors into buckets
|
||||||
buckets = self._hash_vectors(query_key_vectors, num_hashes)
|
buckets = self._hash_vectors(query_key_vectors, num_hashes, attention_mask)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
int(buckets.shape[-1]) == num_hashes * sequence_length
|
int(buckets.shape[-1]) == num_hashes * sequence_length
|
||||||
@@ -460,7 +460,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
|||||||
|
|
||||||
return LSHSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs, buckets=buckets)
|
return LSHSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs, buckets=buckets)
|
||||||
|
|
||||||
def _hash_vectors(self, vectors, num_hashes):
|
def _hash_vectors(self, vectors, num_hashes, attention_mask):
|
||||||
batch_size = vectors.shape[0]
|
batch_size = vectors.shape[0]
|
||||||
|
|
||||||
# See https://arxiv.org/pdf/1509.02897.pdf
|
# See https://arxiv.org/pdf/1509.02897.pdf
|
||||||
@@ -514,6 +514,15 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
|||||||
|
|
||||||
cur_product = cur_product * bucket_factor
|
cur_product = cur_product * bucket_factor
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
# add an extra bucket for padding tokens only
|
||||||
|
num_buckets = num_buckets + 1
|
||||||
|
# assign padding tokens extra bucket
|
||||||
|
buckets_mask = attention_mask.to(torch.uint8)[:, None, None, :].expand(buckets.shape)
|
||||||
|
buckets = torch.where(
|
||||||
|
buckets_mask, buckets, torch.tensor(num_buckets - 1, dtype=torch.long, device=buckets.device)
|
||||||
|
)
|
||||||
|
|
||||||
# buckets is now (Batch_size x Num_Attn_Heads x Num_Hashes x Seq_Len).
|
# buckets is now (Batch_size x Num_Attn_Heads x Num_Hashes x Seq_Len).
|
||||||
# Next we add offsets so that bucket numbers from different hashing rounds don't overlap.
|
# Next we add offsets so that bucket numbers from different hashing rounds don't overlap.
|
||||||
offsets = torch.arange(num_hashes, device=vectors.device)
|
offsets = torch.arange(num_hashes, device=vectors.device)
|
||||||
@@ -614,7 +623,9 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
|||||||
self_mask_value = self.self_mask_value_float32
|
self_mask_value = self.self_mask_value_float32
|
||||||
mask_value = self.mask_value_float32
|
mask_value = self.mask_value_float32
|
||||||
|
|
||||||
mask = self._compute_attn_mask(query_bucket_idx, key_value_bucket_idx, attention_mask, sequence_length)
|
mask = self._compute_attn_mask(
|
||||||
|
query_bucket_idx, key_value_bucket_idx, attention_mask, query_key_dots.shape, sequence_length
|
||||||
|
)
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
query_key_dots = torch.where(mask, query_key_dots, mask_value)
|
query_key_dots = torch.where(mask, query_key_dots, mask_value)
|
||||||
@@ -669,45 +680,32 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
|||||||
|
|
||||||
return out_vectors, logits, attention_probs
|
return out_vectors, logits, attention_probs
|
||||||
|
|
||||||
def _compute_attn_mask(self, query_indices, key_indices, attention_mask, sequence_length):
|
def _compute_attn_mask(self, query_indices, key_indices, attention_mask, query_key_dot_shape, sequence_length):
|
||||||
mask = None
|
|
||||||
|
|
||||||
# Causal mask
|
# attention mask for LSH
|
||||||
if self.is_decoder:
|
|
||||||
mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to(query_indices.device)
|
|
||||||
|
|
||||||
# Attention mask: chunk, look up correct mask value from key_value_bucket_idx
|
|
||||||
# IMPORTANT: official trax code does not use a mask for LSH Atttention. Not sure why.
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# if chunked attention, the attention mask has to correspond to LSH order
|
# if chunked attention, the attention mask has to correspond to LSH order
|
||||||
if sequence_length > self.chunk_length:
|
|
||||||
attention_mask = attention_mask.to(torch.uint8)[:, None, None, :]
|
|
||||||
# expand attn_mask to fit with key_value_bucket_idx shape
|
|
||||||
attention_mask = attention_mask.expand(query_indices.shape[:-1] + (-1,))
|
|
||||||
key_attn_mask = torch.gather(attention_mask, -1, key_indices)
|
|
||||||
query_attn_mask = torch.gather(attention_mask, -1, query_indices)
|
|
||||||
# expand to query_key_dots shape: duplicate along query axis since key sorting is the same for each query position in chunk
|
|
||||||
attn_mask = query_attn_mask.unsqueeze(-1) * key_attn_mask.unsqueeze(-2)
|
|
||||||
|
|
||||||
# free memory
|
|
||||||
del query_attn_mask, key_attn_mask
|
|
||||||
else:
|
|
||||||
# usual attention mask creation
|
|
||||||
attention_mask = attention_mask.to(torch.uint8)[:, None, :]
|
attention_mask = attention_mask.to(torch.uint8)[:, None, :]
|
||||||
attn_mask = (attention_mask.unsqueeze(-1) * attention_mask.unsqueeze(-2)).expand(
|
if sequence_length > self.chunk_length:
|
||||||
query_indices.shape + attention_mask.shape[-1:]
|
# expand attn_mask to fit with key_value_bucket_idx shape
|
||||||
)
|
attention_mask = attention_mask[:, None, :]
|
||||||
|
attention_mask = attention_mask.expand(query_indices.shape[:-1] + (-1,))
|
||||||
|
# extract attention mask from LSH sorted key_indices
|
||||||
|
attention_mask = torch.gather(attention_mask, -1, key_indices)
|
||||||
|
|
||||||
# free memory
|
attention_mask = attention_mask.unsqueeze(-2).expand(query_key_dot_shape)
|
||||||
del attention_mask
|
|
||||||
|
|
||||||
# multiply by casaul mask if necessary
|
# Causal mask
|
||||||
if mask is not None:
|
if self.is_decoder is True:
|
||||||
mask = mask * attn_mask
|
causal_mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to(query_indices.device)
|
||||||
|
|
||||||
|
# add attention mask if not None
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = causal_mask * attention_mask
|
||||||
else:
|
else:
|
||||||
mask = attn_mask
|
attention_mask = causal_mask
|
||||||
|
|
||||||
return mask
|
return attention_mask
|
||||||
|
|
||||||
def _len_and_dim_norm(self, vectors):
|
def _len_and_dim_norm(self, vectors):
|
||||||
"""
|
"""
|
||||||
@@ -923,7 +921,6 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
|
|||||||
return LocalSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs)
|
return LocalSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs)
|
||||||
|
|
||||||
def _compute_attn_mask(self, query_indices, key_indices, attention_mask, query_key_dots_shape, sequence_length):
|
def _compute_attn_mask(self, query_indices, key_indices, attention_mask, query_key_dots_shape, sequence_length):
|
||||||
mask = None
|
|
||||||
|
|
||||||
# chunk attention mask and look before and after
|
# chunk attention mask and look before and after
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
@@ -931,24 +928,21 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
|
|||||||
|
|
||||||
if self.chunk_length < sequence_length:
|
if self.chunk_length < sequence_length:
|
||||||
attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1)
|
attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1)
|
||||||
attention_mask_key = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after)
|
attention_mask = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after)
|
||||||
else:
|
# create attn_mask
|
||||||
attention_mask_key = attention_mask
|
attention_mask = attention_mask.unsqueeze(-2).expand(query_key_dots_shape)
|
||||||
|
|
||||||
# Causal mask
|
# Causal mask
|
||||||
if self.is_decoder is True:
|
if self.is_decoder is True:
|
||||||
mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to(query_indices.device)
|
causal_mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to(query_indices.device)
|
||||||
|
|
||||||
# Attention mask
|
# add attention mask if not None
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# create attn_mask
|
attention_mask = causal_mask * attention_mask
|
||||||
attn_mask = (attention_mask.unsqueeze(-1) * attention_mask_key.unsqueeze(-2)).expand(query_key_dots_shape)
|
|
||||||
# multiply by casaul mask if necessary
|
|
||||||
if mask is not None:
|
|
||||||
mask = mask * attn_mask
|
|
||||||
else:
|
else:
|
||||||
mask = attn_mask
|
attention_mask = causal_mask
|
||||||
return mask
|
|
||||||
|
return attention_mask
|
||||||
|
|
||||||
|
|
||||||
class ReformerSelfOutput(nn.Module):
|
class ReformerSelfOutput(nn.Module):
|
||||||
|
|||||||
@@ -407,7 +407,8 @@ class ReformerModelTester:
|
|||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.half()
|
model.half()
|
||||||
model.eval()
|
model.eval()
|
||||||
output = model.generate(input_ids, attention_mask=input_mask, do_sample=False)
|
# only use last 10 inputs for generation
|
||||||
|
output = model.generate(input_ids[:, -10:], attention_mask=input_mask, do_sample=False)
|
||||||
self.parent.assertFalse(torch.isnan(output).any().item())
|
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||||
|
|
||||||
def create_and_check_reformer_no_chunking(self, config, input_ids, input_mask, choice_labels):
|
def create_and_check_reformer_no_chunking(self, config, input_ids, input_mask, choice_labels):
|
||||||
@@ -623,7 +624,7 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
|
|||||||
@require_torch
|
@require_torch
|
||||||
class ReformerIntegrationTests(unittest.TestCase):
|
class ReformerIntegrationTests(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
These integration tests test the current layer activations and gradients againts the output of the Hugging Face Reformer model at time of integration: 29/04/2020. During integration, the model was tested against the output of the official Trax ReformerLM model for various cases ("lsh" only, "local" only, masked / non-masked, different chunk length, ....). In order to recover the original trax integration tests, one should use patrickvonplaten's fork of trax and the code that lives on the branch `branch_to_save_trax_integration_tests`.
|
These integration tests test the current layer activations and gradients againts the output of the Hugging Face Reformer model at time of integration: 29/06/2020. During integration, the model was tested against the output of the official Trax ReformerLM model for various cases ("lsh" only, "local" only, masked / non-masked, different chunk length, ....). In order to recover the original trax integration tests, one should use patrickvonplaten's fork of trax and the code that lives on the branch `reformer_trax_tests`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_basic_config_and_input(self):
|
def _get_basic_config_and_input(self):
|
||||||
@@ -940,7 +941,7 @@ class ReformerIntegrationTests(unittest.TestCase):
|
|||||||
hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0]
|
hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0]
|
||||||
output_slice = hidden_states[1, -1, :5]
|
output_slice = hidden_states[1, -1, :5]
|
||||||
expected_output_slice = torch.tensor(
|
expected_output_slice = torch.tensor(
|
||||||
[0.0324, -0.0121, 0.0615, 0.0031, -0.0297], dtype=torch.float, device=torch_device,
|
[0.0256, -0.0121, 0.0636, 0.0024, -0.0393], dtype=torch.float, device=torch_device,
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user