From 9ca485734aea269961d63a040ff194365d151fd1 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 2 Jun 2020 23:08:39 +0200 Subject: [PATCH] [Reformer] Improved memory if input is shorter than chunk length (#4720) * improve handling of short inputs for reformer * correct typo in assert statement * fix other tests --- src/transformers/modeling_reformer.py | 235 +++++++++++++++----------- src/transformers/modeling_utils.py | 2 +- tests/test_modeling_reformer.py | 20 ++- 3 files changed, 161 insertions(+), 96 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index d48823a412..c24a4219ea 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -351,41 +351,48 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): value_vectors.shape[-1], self.attention_head_size ) - # set `num_buckets` on the fly, recommended way to do it - if self.num_buckets is None: - self._set_num_buckets(sequence_length) + # LSH attention only makes sense if chunked attention should be performed + if self.chunk_length < sequence_length: + # set `num_buckets` on the fly, recommended way to do it + if self.num_buckets is None: + self._set_num_buckets(sequence_length) - # use cached buckets for backprop only - if buckets is None: - # hash query key vectors into buckets - buckets = self._hash_vectors(query_key_vectors, num_hashes) + # use cached buckets for backprop only + if buckets is None: + # hash query key vectors into buckets + buckets = self._hash_vectors(query_key_vectors, num_hashes) - assert ( - int(buckets.shape[-1]) == num_hashes * sequence_length - ), "last dim of buckets is {}, but should be {}".format(buckets.shape[-1], num_hashes * sequence_length) - - sorted_bucket_idx, undo_sorted_bucket_idx = self._get_sorted_bucket_idx_and_undo_sorted_bucket_idx( - sequence_length, buckets, num_hashes - ) - - # make sure bucket idx is not longer then sequence length - sorted_bucket_idx = sorted_bucket_idx % sequence_length - - # cluster query key value vectors according to hashed buckets - query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx, num_hashes) - value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx, num_hashes) - - query_key_vectors = self._split_seq_length_dim_to( - query_key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, - ) - value_vectors = self._split_seq_length_dim_to( - value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, - ) - - if self.chunk_length is None: assert ( - self.num_chunks_before == 0 and self.num_chunks_after == 0 - ), "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0." + int(buckets.shape[-1]) == num_hashes * sequence_length + ), "last dim of buckets is {}, but should be {}".format(buckets.shape[-1], num_hashes * sequence_length) + + sorted_bucket_idx, undo_sorted_bucket_idx = self._get_sorted_bucket_idx_and_undo_sorted_bucket_idx( + sequence_length, buckets, num_hashes + ) + + # make sure bucket idx is not longer then sequence length + sorted_bucket_idx = sorted_bucket_idx % sequence_length + + # cluster query key value vectors according to hashed buckets + query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx, num_hashes) + value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx, num_hashes) + + query_key_vectors = self._split_seq_length_dim_to( + query_key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, + ) + value_vectors = self._split_seq_length_dim_to( + value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, + ) + + if self.chunk_length is None: + assert ( + self.num_chunks_before == 0 and self.num_chunks_after == 0 + ), "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0." + else: + # get sequence length indices + sorted_bucket_idx = torch.arange(sequence_length, device=query_key_vectors.device).repeat( + batch_size, self.num_attention_heads, 1 + ) # scale key vectors key_vectors = self._len_and_dim_norm(query_key_vectors) @@ -398,31 +405,35 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): sorted_bucket_idx=sorted_bucket_idx, attention_mask=attention_mask, head_mask=head_mask, + sequence_length=sequence_length, ) + # free memory del query_key_vectors, key_vectors, value_vectors - # sort clusters back to correct ordering - out_vectors, logits = ReverseSort.apply( - out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, self.num_hashes - ) - - # sum up all hash rounds - if num_hashes > 1: - out_vectors = self._split_seq_length_dim_to( - out_vectors, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size, + # re-order out_vectors and logits + if self.chunk_length < sequence_length: + # sort clusters back to correct ordering + out_vectors, logits = ReverseSort.apply( + out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, self.num_hashes ) - logits = self._split_seq_length_dim_to( - logits, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size, - ).unsqueeze(-1) - probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True)) - out_vectors = torch.sum(out_vectors * probs_vectors, dim=2) + # sum up all hash rounds + if num_hashes > 1: + out_vectors = self._split_seq_length_dim_to( + out_vectors, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size, + ) + logits = self._split_seq_length_dim_to( + logits, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size, + ).unsqueeze(-1) + + probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True)) + out_vectors = torch.sum(out_vectors * probs_vectors, dim=2) + # free memory + del probs_vectors + # free memory - del probs_vectors - - # free memory - del logits + del logits assert out_vectors.shape == ( batch_size, @@ -554,10 +565,13 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): self.num_buckets = num_buckets def _attend( - self, query_vectors, key_vectors, value_vectors, sorted_bucket_idx, attention_mask, head_mask, + self, query_vectors, key_vectors, value_vectors, sorted_bucket_idx, attention_mask, head_mask, sequence_length ): - key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) - value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after) + + # look at previous and following chunks if chunked attention + if self.chunk_length < sequence_length: + key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) + value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after) # get logits and dots query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) @@ -565,10 +579,14 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): # free memory del query_vectors, key_vectors - query_bucket_idx = self._split_seq_length_dim_to( - sorted_bucket_idx, -1, self.chunk_length, self.num_attention_heads - ) - key_value_bucket_idx = self._look_adjacent(query_bucket_idx, self.num_chunks_before, self.num_chunks_after) + # if chunked attention split bucket idxs to query and key + if self.chunk_length < sequence_length: + query_bucket_idx = self._split_seq_length_dim_to( + sorted_bucket_idx, -1, self.chunk_length, self.num_attention_heads + ) + key_value_bucket_idx = self._look_adjacent(query_bucket_idx, self.num_chunks_before, self.num_chunks_after) + else: + query_bucket_idx = key_value_bucket_idx = sorted_bucket_idx # get correct mask values depending on precision if query_key_dots.dtype == torch.float16: @@ -578,7 +596,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): self_mask_value = self.self_mask_value_float32 mask_value = self.mask_value_float32 - mask = self._compute_attn_mask(query_bucket_idx, key_value_bucket_idx, attention_mask) + mask = self._compute_attn_mask(query_bucket_idx, key_value_bucket_idx, attention_mask, sequence_length) if mask is not None: query_key_dots = torch.where(mask, query_key_dots, mask_value) @@ -627,12 +645,13 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): del value_vectors # merge chunk length - logits = logits.flatten(start_dim=2, end_dim=3).squeeze(-1) - out_vectors = out_vectors.flatten(start_dim=2, end_dim=3) + if self.chunk_length < sequence_length: + logits = logits.flatten(start_dim=2, end_dim=3).squeeze(-1) + out_vectors = out_vectors.flatten(start_dim=2, end_dim=3) return out_vectors, logits, attention_probs - def _compute_attn_mask(self, query_indices, key_indices, attention_mask): + def _compute_attn_mask(self, query_indices, key_indices, attention_mask, sequence_length): mask = None # Causal mask @@ -642,15 +661,27 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): # Attention mask: chunk, look up correct mask value from key_value_bucket_idx # IMPORTANT: official trax code does not use a mask for LSH Atttention. Not sure why. if attention_mask is not None: - attention_mask = attention_mask.to(torch.uint8)[:, None, None, :] - # expand attn_mask to fit with key_value_bucket_idx shape - attention_mask = attention_mask.expand(query_indices.shape[:-1] + (-1,)) - key_attn_mask = torch.gather(attention_mask, -1, key_indices) - query_attn_mask = torch.gather(attention_mask, -1, query_indices) - # expand to query_key_dots shape: duplicate along query axis since key sorting is the same for each query position in chunk - attn_mask = query_attn_mask.unsqueeze(-1) * key_attn_mask.unsqueeze(-2) + # if chunked attention, the attention mask has to correspond to LSH order + if sequence_length > self.chunk_length: + attention_mask = attention_mask.to(torch.uint8)[:, None, None, :] + # expand attn_mask to fit with key_value_bucket_idx shape + attention_mask = attention_mask.expand(query_indices.shape[:-1] + (-1,)) + key_attn_mask = torch.gather(attention_mask, -1, key_indices) + query_attn_mask = torch.gather(attention_mask, -1, query_indices) + # expand to query_key_dots shape: duplicate along query axis since key sorting is the same for each query position in chunk + attn_mask = query_attn_mask.unsqueeze(-1) * key_attn_mask.unsqueeze(-2) + + # free memory + del query_attn_mask, key_attn_mask + else: + # usual attention mask creation + attention_mask = attention_mask.to(torch.uint8)[:, None, :] + attn_mask = (attention_mask.unsqueeze(-1) * attention_mask.unsqueeze(-2)).expand( + query_indices.shape + attention_mask.shape[-1:] + ) + # free memory - del query_attn_mask, key_attn_mask, attention_mask + del attention_mask # multiply by casaul mask if necessary if mask is not None: @@ -810,36 +841,45 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin): torch.tensor(self.attention_head_size, device=key_vectors.device, dtype=key_vectors.dtype) ) - # chunk vectors - # B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size - query_vectors = self._split_seq_length_dim_to( - query_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, - ) - key_vectors = self._split_seq_length_dim_to( - key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, - ) - value_vectors = self._split_seq_length_dim_to( - value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, - ) - - # chunk indices + # get sequence length indices indices = torch.arange(sequence_length, device=query_vectors.device).repeat( batch_size, self.num_attention_heads, 1 ) - query_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads) - key_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads) - # append chunks before and after - key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) - value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after) - key_indices = self._look_adjacent(key_indices, self.num_chunks_before, self.num_chunks_after) + # if input should be chunked + if self.chunk_length < sequence_length: + # chunk vectors + # B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size + query_vectors = self._split_seq_length_dim_to( + query_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, + ) + key_vectors = self._split_seq_length_dim_to( + key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, + ) + value_vectors = self._split_seq_length_dim_to( + value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, + ) + # chunk indices + query_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads) + key_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads) + + # append chunks before and after + key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) + value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after) + key_indices = self._look_adjacent(key_indices, self.num_chunks_before, self.num_chunks_after) + else: + query_indices = key_indices = indices + + # query-key matmul: QK^T query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) # free memory del query_vectors, key_vectors - mask = self._compute_attn_mask(query_indices, key_indices, attention_mask, query_key_dots.shape) + mask = self._compute_attn_mask( + query_indices, key_indices, attention_mask, query_key_dots.shape, sequence_length + ) if mask is not None: # get mask tensor depending on half precision or not @@ -874,7 +914,8 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin): del value_vectors # merge chunk length - out_vectors = out_vectors.flatten(start_dim=2, end_dim=3) + if self.chunk_length < sequence_length: + out_vectors = out_vectors.flatten(start_dim=2, end_dim=3) assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size,) @@ -885,14 +926,18 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin): return LocalSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs) - def _compute_attn_mask(self, query_indices, key_indices, attention_mask, query_key_dots_shape): + def _compute_attn_mask(self, query_indices, key_indices, attention_mask, query_key_dots_shape, sequence_length): mask = None # chunk attention mask and look before and after if attention_mask is not None: attention_mask = attention_mask.to(torch.uint8)[:, None, :] - attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1) - attention_mask_key = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after) + + if self.chunk_length < sequence_length: + attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1) + attention_mask_key = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after) + else: + attention_mask_key = attention_mask # Causal mask if self.is_decoder is True: @@ -1564,7 +1609,9 @@ class ReformerModel(ReformerPreTrainedModel): # if needs padding least_common_mult_chunk_length = _get_least_common_mult_chunk_len(self.config) - must_pad_to_match_chunk_length = input_shape[-1] % least_common_mult_chunk_length != 0 + must_pad_to_match_chunk_length = ( + input_shape[-1] % least_common_mult_chunk_length != 0 and input_shape[-1] > least_common_mult_chunk_length + ) if must_pad_to_match_chunk_length: padding_length = least_common_mult_chunk_length - input_shape[-1] % least_common_mult_chunk_length diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 42ebf0057e..a04b21c6b3 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2188,7 +2188,7 @@ def apply_chunking_to_forward( assert ( input_tensors[0].shape[chunk_dim] % chunk_size == 0 ), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format( - input_tensors[0][chunk_dim], chunk_size + input_tensors[0].shape[chunk_dim], chunk_size ) num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 60ba91ada9..7ad651918b 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -388,6 +388,16 @@ class ReformerModelTester: output = model.generate(input_ids, attention_mask=input_mask, do_sample=False) self.parent.assertFalse(torch.isnan(output).any().item()) + def create_and_check_reformer_no_chunking(self, config, input_ids, input_mask): + # force chunk length to be bigger than input_ids + config.lsh_attn_chunk_length = 2 * input_ids.shape[-1] + config.local_attn_chunk_length = 2 * input_ids.shape[-1] + model = ReformerModelWithLMHead(config=config) + model.to(torch_device) + model.eval() + output_logits = model(input_ids, attention_mask=input_mask)[0] + self.parent.assertTrue(output_logits.shape[1] == input_ids.shape[-1]) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() (config, input_ids, input_mask,) = config_and_inputs @@ -433,6 +443,10 @@ class ReformerTesterMixin: config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_reformer_feed_backward_chunking(*config_and_inputs) + def test_reformer_no_chunking(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_no_chunking(*config_and_inputs) + @slow def test_dropout_random_seed_is_changing(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() @@ -772,6 +786,7 @@ class ReformerIntegrationTests(unittest.TestCase): def test_lsh_layer_forward(self): config = self._get_basic_config_and_input() + config["lsh_num_chunks_before"] = 0 config["attn_layers"] = ["lsh"] config["is_decoder"] = False hidden_states = self._get_hidden_states() @@ -787,6 +802,7 @@ class ReformerIntegrationTests(unittest.TestCase): def test_lsh_layer_forward_complex(self): config = self._get_basic_config_and_input() + config["lsh_num_chunks_before"] = 0 config["attn_layers"] = ["lsh"] config["num_buckets"] = [2, 4] attn_mask = self._get_attn_mask() @@ -805,6 +821,7 @@ class ReformerIntegrationTests(unittest.TestCase): def test_local_layer_forward(self): config = self._get_basic_config_and_input() + config["local_num_chunks_before"] = 0 config["attn_layers"] = ["local"] config["is_decoder"] = False hidden_states = self._get_hidden_states() @@ -820,6 +837,7 @@ class ReformerIntegrationTests(unittest.TestCase): def test_local_layer_forward_complex(self): config = self._get_basic_config_and_input() + config["local_num_chunks_before"] = 0 config["attn_layers"] = ["local"] attn_mask = self._get_attn_mask() hidden_states = self._get_hidden_states() @@ -829,7 +847,7 @@ class ReformerIntegrationTests(unittest.TestCase): reformer_output = layer(prev_attn_output=hidden_states, hidden_states=hidden_states, attention_mask=attn_mask,) output_slice = reformer_output.hidden_states[0, 0, :5] expected_output_slice = torch.tensor( - [1.5476, -1.9020, -0.9902, 1.5013, -0.1950], dtype=torch.float, device=torch_device, + [1.4750, -2.0235, -0.9743, 1.4463, -0.1269], dtype=torch.float, device=torch_device, ) self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))