[Reformer] Improved memory if input is shorter than chunk length (#4720)
* improve handling of short inputs for reformer * correct typo in assert statement * fix other tests
This commit is contained in:
committed by
GitHub
parent
b231a413f5
commit
9ca485734a
@@ -351,6 +351,8 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
||||
value_vectors.shape[-1], self.attention_head_size
|
||||
)
|
||||
|
||||
# LSH attention only makes sense if chunked attention should be performed
|
||||
if self.chunk_length < sequence_length:
|
||||
# set `num_buckets` on the fly, recommended way to do it
|
||||
if self.num_buckets is None:
|
||||
self._set_num_buckets(sequence_length)
|
||||
@@ -386,6 +388,11 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
||||
assert (
|
||||
self.num_chunks_before == 0 and self.num_chunks_after == 0
|
||||
), "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0."
|
||||
else:
|
||||
# get sequence length indices
|
||||
sorted_bucket_idx = torch.arange(sequence_length, device=query_key_vectors.device).repeat(
|
||||
batch_size, self.num_attention_heads, 1
|
||||
)
|
||||
|
||||
# scale key vectors
|
||||
key_vectors = self._len_and_dim_norm(query_key_vectors)
|
||||
@@ -398,10 +405,14 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
||||
sorted_bucket_idx=sorted_bucket_idx,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
sequence_length=sequence_length,
|
||||
)
|
||||
|
||||
# free memory
|
||||
del query_key_vectors, key_vectors, value_vectors
|
||||
|
||||
# re-order out_vectors and logits
|
||||
if self.chunk_length < sequence_length:
|
||||
# sort clusters back to correct ordering
|
||||
out_vectors, logits = ReverseSort.apply(
|
||||
out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, self.num_hashes
|
||||
@@ -554,8 +565,11 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
||||
self.num_buckets = num_buckets
|
||||
|
||||
def _attend(
|
||||
self, query_vectors, key_vectors, value_vectors, sorted_bucket_idx, attention_mask, head_mask,
|
||||
self, query_vectors, key_vectors, value_vectors, sorted_bucket_idx, attention_mask, head_mask, sequence_length
|
||||
):
|
||||
|
||||
# look at previous and following chunks if chunked attention
|
||||
if self.chunk_length < sequence_length:
|
||||
key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after)
|
||||
value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after)
|
||||
|
||||
@@ -565,10 +579,14 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
||||
# free memory
|
||||
del query_vectors, key_vectors
|
||||
|
||||
# if chunked attention split bucket idxs to query and key
|
||||
if self.chunk_length < sequence_length:
|
||||
query_bucket_idx = self._split_seq_length_dim_to(
|
||||
sorted_bucket_idx, -1, self.chunk_length, self.num_attention_heads
|
||||
)
|
||||
key_value_bucket_idx = self._look_adjacent(query_bucket_idx, self.num_chunks_before, self.num_chunks_after)
|
||||
else:
|
||||
query_bucket_idx = key_value_bucket_idx = sorted_bucket_idx
|
||||
|
||||
# get correct mask values depending on precision
|
||||
if query_key_dots.dtype == torch.float16:
|
||||
@@ -578,7 +596,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
||||
self_mask_value = self.self_mask_value_float32
|
||||
mask_value = self.mask_value_float32
|
||||
|
||||
mask = self._compute_attn_mask(query_bucket_idx, key_value_bucket_idx, attention_mask)
|
||||
mask = self._compute_attn_mask(query_bucket_idx, key_value_bucket_idx, attention_mask, sequence_length)
|
||||
|
||||
if mask is not None:
|
||||
query_key_dots = torch.where(mask, query_key_dots, mask_value)
|
||||
@@ -627,12 +645,13 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
||||
del value_vectors
|
||||
|
||||
# merge chunk length
|
||||
if self.chunk_length < sequence_length:
|
||||
logits = logits.flatten(start_dim=2, end_dim=3).squeeze(-1)
|
||||
out_vectors = out_vectors.flatten(start_dim=2, end_dim=3)
|
||||
|
||||
return out_vectors, logits, attention_probs
|
||||
|
||||
def _compute_attn_mask(self, query_indices, key_indices, attention_mask):
|
||||
def _compute_attn_mask(self, query_indices, key_indices, attention_mask, sequence_length):
|
||||
mask = None
|
||||
|
||||
# Causal mask
|
||||
@@ -642,6 +661,8 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
||||
# Attention mask: chunk, look up correct mask value from key_value_bucket_idx
|
||||
# IMPORTANT: official trax code does not use a mask for LSH Atttention. Not sure why.
|
||||
if attention_mask is not None:
|
||||
# if chunked attention, the attention mask has to correspond to LSH order
|
||||
if sequence_length > self.chunk_length:
|
||||
attention_mask = attention_mask.to(torch.uint8)[:, None, None, :]
|
||||
# expand attn_mask to fit with key_value_bucket_idx shape
|
||||
attention_mask = attention_mask.expand(query_indices.shape[:-1] + (-1,))
|
||||
@@ -649,8 +670,18 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
||||
query_attn_mask = torch.gather(attention_mask, -1, query_indices)
|
||||
# expand to query_key_dots shape: duplicate along query axis since key sorting is the same for each query position in chunk
|
||||
attn_mask = query_attn_mask.unsqueeze(-1) * key_attn_mask.unsqueeze(-2)
|
||||
|
||||
# free memory
|
||||
del query_attn_mask, key_attn_mask, attention_mask
|
||||
del query_attn_mask, key_attn_mask
|
||||
else:
|
||||
# usual attention mask creation
|
||||
attention_mask = attention_mask.to(torch.uint8)[:, None, :]
|
||||
attn_mask = (attention_mask.unsqueeze(-1) * attention_mask.unsqueeze(-2)).expand(
|
||||
query_indices.shape + attention_mask.shape[-1:]
|
||||
)
|
||||
|
||||
# free memory
|
||||
del attention_mask
|
||||
|
||||
# multiply by casaul mask if necessary
|
||||
if mask is not None:
|
||||
@@ -810,6 +841,13 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
|
||||
torch.tensor(self.attention_head_size, device=key_vectors.device, dtype=key_vectors.dtype)
|
||||
)
|
||||
|
||||
# get sequence length indices
|
||||
indices = torch.arange(sequence_length, device=query_vectors.device).repeat(
|
||||
batch_size, self.num_attention_heads, 1
|
||||
)
|
||||
|
||||
# if input should be chunked
|
||||
if self.chunk_length < sequence_length:
|
||||
# chunk vectors
|
||||
# B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size
|
||||
query_vectors = self._split_seq_length_dim_to(
|
||||
@@ -823,9 +861,6 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
|
||||
)
|
||||
|
||||
# chunk indices
|
||||
indices = torch.arange(sequence_length, device=query_vectors.device).repeat(
|
||||
batch_size, self.num_attention_heads, 1
|
||||
)
|
||||
query_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads)
|
||||
key_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads)
|
||||
|
||||
@@ -833,13 +868,18 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
|
||||
key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after)
|
||||
value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after)
|
||||
key_indices = self._look_adjacent(key_indices, self.num_chunks_before, self.num_chunks_after)
|
||||
else:
|
||||
query_indices = key_indices = indices
|
||||
|
||||
# query-key matmul: QK^T
|
||||
query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2))
|
||||
|
||||
# free memory
|
||||
del query_vectors, key_vectors
|
||||
|
||||
mask = self._compute_attn_mask(query_indices, key_indices, attention_mask, query_key_dots.shape)
|
||||
mask = self._compute_attn_mask(
|
||||
query_indices, key_indices, attention_mask, query_key_dots.shape, sequence_length
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
# get mask tensor depending on half precision or not
|
||||
@@ -874,6 +914,7 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
|
||||
del value_vectors
|
||||
|
||||
# merge chunk length
|
||||
if self.chunk_length < sequence_length:
|
||||
out_vectors = out_vectors.flatten(start_dim=2, end_dim=3)
|
||||
|
||||
assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size,)
|
||||
@@ -885,14 +926,18 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
|
||||
|
||||
return LocalSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs)
|
||||
|
||||
def _compute_attn_mask(self, query_indices, key_indices, attention_mask, query_key_dots_shape):
|
||||
def _compute_attn_mask(self, query_indices, key_indices, attention_mask, query_key_dots_shape, sequence_length):
|
||||
mask = None
|
||||
|
||||
# chunk attention mask and look before and after
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(torch.uint8)[:, None, :]
|
||||
|
||||
if self.chunk_length < sequence_length:
|
||||
attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1)
|
||||
attention_mask_key = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after)
|
||||
else:
|
||||
attention_mask_key = attention_mask
|
||||
|
||||
# Causal mask
|
||||
if self.is_decoder is True:
|
||||
@@ -1564,7 +1609,9 @@ class ReformerModel(ReformerPreTrainedModel):
|
||||
|
||||
# if needs padding
|
||||
least_common_mult_chunk_length = _get_least_common_mult_chunk_len(self.config)
|
||||
must_pad_to_match_chunk_length = input_shape[-1] % least_common_mult_chunk_length != 0
|
||||
must_pad_to_match_chunk_length = (
|
||||
input_shape[-1] % least_common_mult_chunk_length != 0 and input_shape[-1] > least_common_mult_chunk_length
|
||||
)
|
||||
|
||||
if must_pad_to_match_chunk_length:
|
||||
padding_length = least_common_mult_chunk_length - input_shape[-1] % least_common_mult_chunk_length
|
||||
|
||||
@@ -2188,7 +2188,7 @@ def apply_chunking_to_forward(
|
||||
assert (
|
||||
input_tensors[0].shape[chunk_dim] % chunk_size == 0
|
||||
), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format(
|
||||
input_tensors[0][chunk_dim], chunk_size
|
||||
input_tensors[0].shape[chunk_dim], chunk_size
|
||||
)
|
||||
|
||||
num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
|
||||
|
||||
@@ -388,6 +388,16 @@ class ReformerModelTester:
|
||||
output = model.generate(input_ids, attention_mask=input_mask, do_sample=False)
|
||||
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||
|
||||
def create_and_check_reformer_no_chunking(self, config, input_ids, input_mask):
|
||||
# force chunk length to be bigger than input_ids
|
||||
config.lsh_attn_chunk_length = 2 * input_ids.shape[-1]
|
||||
config.local_attn_chunk_length = 2 * input_ids.shape[-1]
|
||||
model = ReformerModelWithLMHead(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
output_logits = model(input_ids, attention_mask=input_mask)[0]
|
||||
self.parent.assertTrue(output_logits.shape[1] == input_ids.shape[-1])
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, input_ids, input_mask,) = config_and_inputs
|
||||
@@ -433,6 +443,10 @@ class ReformerTesterMixin:
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_reformer_feed_backward_chunking(*config_and_inputs)
|
||||
|
||||
def test_reformer_no_chunking(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_reformer_no_chunking(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_dropout_random_seed_is_changing(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
@@ -772,6 +786,7 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
|
||||
def test_lsh_layer_forward(self):
|
||||
config = self._get_basic_config_and_input()
|
||||
config["lsh_num_chunks_before"] = 0
|
||||
config["attn_layers"] = ["lsh"]
|
||||
config["is_decoder"] = False
|
||||
hidden_states = self._get_hidden_states()
|
||||
@@ -787,6 +802,7 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
|
||||
def test_lsh_layer_forward_complex(self):
|
||||
config = self._get_basic_config_and_input()
|
||||
config["lsh_num_chunks_before"] = 0
|
||||
config["attn_layers"] = ["lsh"]
|
||||
config["num_buckets"] = [2, 4]
|
||||
attn_mask = self._get_attn_mask()
|
||||
@@ -805,6 +821,7 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
|
||||
def test_local_layer_forward(self):
|
||||
config = self._get_basic_config_and_input()
|
||||
config["local_num_chunks_before"] = 0
|
||||
config["attn_layers"] = ["local"]
|
||||
config["is_decoder"] = False
|
||||
hidden_states = self._get_hidden_states()
|
||||
@@ -820,6 +837,7 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
|
||||
def test_local_layer_forward_complex(self):
|
||||
config = self._get_basic_config_and_input()
|
||||
config["local_num_chunks_before"] = 0
|
||||
config["attn_layers"] = ["local"]
|
||||
attn_mask = self._get_attn_mask()
|
||||
hidden_states = self._get_hidden_states()
|
||||
@@ -829,7 +847,7 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
reformer_output = layer(prev_attn_output=hidden_states, hidden_states=hidden_states, attention_mask=attn_mask,)
|
||||
output_slice = reformer_output.hidden_states[0, 0, :5]
|
||||
expected_output_slice = torch.tensor(
|
||||
[1.5476, -1.9020, -0.9902, 1.5013, -0.1950], dtype=torch.float, device=torch_device,
|
||||
[1.4750, -2.0235, -0.9743, 1.4463, -0.1269], dtype=torch.float, device=torch_device,
|
||||
)
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user