From 989ae326b572b3729b78ac8ac1c6c5859c104af7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 7 Jul 2020 10:48:06 +0200 Subject: [PATCH] [Reformer] Adapt Reformer MaskedLM Attn mask (#5560) * fix attention mask * fix slow test * refactor attn masks * fix fp16 generate test --- src/transformers/modeling_reformer.py | 88 +++++++++++++-------------- tests/test_modeling_reformer.py | 7 ++- 2 files changed, 45 insertions(+), 50 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 89cbf14403..dfd3c63117 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -373,7 +373,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): # use cached buckets for backprop only if buckets is None: # hash query key vectors into buckets - buckets = self._hash_vectors(query_key_vectors, num_hashes) + buckets = self._hash_vectors(query_key_vectors, num_hashes, attention_mask) assert ( int(buckets.shape[-1]) == num_hashes * sequence_length @@ -460,7 +460,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): return LSHSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs, buckets=buckets) - def _hash_vectors(self, vectors, num_hashes): + def _hash_vectors(self, vectors, num_hashes, attention_mask): batch_size = vectors.shape[0] # See https://arxiv.org/pdf/1509.02897.pdf @@ -514,6 +514,15 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): cur_product = cur_product * bucket_factor + if attention_mask is not None: + # add an extra bucket for padding tokens only + num_buckets = num_buckets + 1 + # assign padding tokens extra bucket + buckets_mask = attention_mask.to(torch.uint8)[:, None, None, :].expand(buckets.shape) + buckets = torch.where( + buckets_mask, buckets, torch.tensor(num_buckets - 1, dtype=torch.long, device=buckets.device) + ) + # buckets is now (Batch_size x Num_Attn_Heads x Num_Hashes x Seq_Len). # Next we add offsets so that bucket numbers from different hashing rounds don't overlap. offsets = torch.arange(num_hashes, device=vectors.device) @@ -614,7 +623,9 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): self_mask_value = self.self_mask_value_float32 mask_value = self.mask_value_float32 - mask = self._compute_attn_mask(query_bucket_idx, key_value_bucket_idx, attention_mask, sequence_length) + mask = self._compute_attn_mask( + query_bucket_idx, key_value_bucket_idx, attention_mask, query_key_dots.shape, sequence_length + ) if mask is not None: query_key_dots = torch.where(mask, query_key_dots, mask_value) @@ -669,45 +680,32 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): return out_vectors, logits, attention_probs - def _compute_attn_mask(self, query_indices, key_indices, attention_mask, sequence_length): - mask = None + def _compute_attn_mask(self, query_indices, key_indices, attention_mask, query_key_dot_shape, sequence_length): - # Causal mask - if self.is_decoder: - mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to(query_indices.device) - - # Attention mask: chunk, look up correct mask value from key_value_bucket_idx - # IMPORTANT: official trax code does not use a mask for LSH Atttention. Not sure why. + # attention mask for LSH if attention_mask is not None: # if chunked attention, the attention mask has to correspond to LSH order + attention_mask = attention_mask.to(torch.uint8)[:, None, :] if sequence_length > self.chunk_length: - attention_mask = attention_mask.to(torch.uint8)[:, None, None, :] # expand attn_mask to fit with key_value_bucket_idx shape + attention_mask = attention_mask[:, None, :] attention_mask = attention_mask.expand(query_indices.shape[:-1] + (-1,)) - key_attn_mask = torch.gather(attention_mask, -1, key_indices) - query_attn_mask = torch.gather(attention_mask, -1, query_indices) - # expand to query_key_dots shape: duplicate along query axis since key sorting is the same for each query position in chunk - attn_mask = query_attn_mask.unsqueeze(-1) * key_attn_mask.unsqueeze(-2) + # extract attention mask from LSH sorted key_indices + attention_mask = torch.gather(attention_mask, -1, key_indices) - # free memory - del query_attn_mask, key_attn_mask + attention_mask = attention_mask.unsqueeze(-2).expand(query_key_dot_shape) + + # Causal mask + if self.is_decoder is True: + causal_mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to(query_indices.device) + + # add attention mask if not None + if attention_mask is not None: + attention_mask = causal_mask * attention_mask else: - # usual attention mask creation - attention_mask = attention_mask.to(torch.uint8)[:, None, :] - attn_mask = (attention_mask.unsqueeze(-1) * attention_mask.unsqueeze(-2)).expand( - query_indices.shape + attention_mask.shape[-1:] - ) + attention_mask = causal_mask - # free memory - del attention_mask - - # multiply by casaul mask if necessary - if mask is not None: - mask = mask * attn_mask - else: - mask = attn_mask - - return mask + return attention_mask def _len_and_dim_norm(self, vectors): """ @@ -923,7 +921,6 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin): return LocalSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs) def _compute_attn_mask(self, query_indices, key_indices, attention_mask, query_key_dots_shape, sequence_length): - mask = None # chunk attention mask and look before and after if attention_mask is not None: @@ -931,24 +928,21 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin): if self.chunk_length < sequence_length: attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1) - attention_mask_key = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after) - else: - attention_mask_key = attention_mask + attention_mask = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after) + # create attn_mask + attention_mask = attention_mask.unsqueeze(-2).expand(query_key_dots_shape) # Causal mask if self.is_decoder is True: - mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to(query_indices.device) + causal_mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to(query_indices.device) - # Attention mask - if attention_mask is not None: - # create attn_mask - attn_mask = (attention_mask.unsqueeze(-1) * attention_mask_key.unsqueeze(-2)).expand(query_key_dots_shape) - # multiply by casaul mask if necessary - if mask is not None: - mask = mask * attn_mask + # add attention mask if not None + if attention_mask is not None: + attention_mask = causal_mask * attention_mask else: - mask = attn_mask - return mask + attention_mask = causal_mask + + return attention_mask class ReformerSelfOutput(nn.Module): diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 7faa43346e..ba1b2a449d 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -407,7 +407,8 @@ class ReformerModelTester: model.to(torch_device) model.half() model.eval() - output = model.generate(input_ids, attention_mask=input_mask, do_sample=False) + # only use last 10 inputs for generation + output = model.generate(input_ids[:, -10:], attention_mask=input_mask, do_sample=False) self.parent.assertFalse(torch.isnan(output).any().item()) def create_and_check_reformer_no_chunking(self, config, input_ids, input_mask, choice_labels): @@ -623,7 +624,7 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T @require_torch class ReformerIntegrationTests(unittest.TestCase): """ - These integration tests test the current layer activations and gradients againts the output of the Hugging Face Reformer model at time of integration: 29/04/2020. During integration, the model was tested against the output of the official Trax ReformerLM model for various cases ("lsh" only, "local" only, masked / non-masked, different chunk length, ....). In order to recover the original trax integration tests, one should use patrickvonplaten's fork of trax and the code that lives on the branch `branch_to_save_trax_integration_tests`. + These integration tests test the current layer activations and gradients againts the output of the Hugging Face Reformer model at time of integration: 29/06/2020. During integration, the model was tested against the output of the official Trax ReformerLM model for various cases ("lsh" only, "local" only, masked / non-masked, different chunk length, ....). In order to recover the original trax integration tests, one should use patrickvonplaten's fork of trax and the code that lives on the branch `reformer_trax_tests`. """ def _get_basic_config_and_input(self): @@ -940,7 +941,7 @@ class ReformerIntegrationTests(unittest.TestCase): hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0] output_slice = hidden_states[1, -1, :5] expected_output_slice = torch.tensor( - [0.0324, -0.0121, 0.0615, 0.0031, -0.0297], dtype=torch.float, device=torch_device, + [0.0256, -0.0121, 0.0636, 0.0024, -0.0393], dtype=torch.float, device=torch_device, ) self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))