[Reformer] Improved memory if input is shorter than chunk length (#4720)
* improve handling of short inputs for reformer * correct typo in assert statement * fix other tests
This commit is contained in:
committed by
GitHub
parent
b231a413f5
commit
9ca485734a
@@ -351,41 +351,48 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
|||||||
value_vectors.shape[-1], self.attention_head_size
|
value_vectors.shape[-1], self.attention_head_size
|
||||||
)
|
)
|
||||||
|
|
||||||
# set `num_buckets` on the fly, recommended way to do it
|
# LSH attention only makes sense if chunked attention should be performed
|
||||||
if self.num_buckets is None:
|
if self.chunk_length < sequence_length:
|
||||||
self._set_num_buckets(sequence_length)
|
# set `num_buckets` on the fly, recommended way to do it
|
||||||
|
if self.num_buckets is None:
|
||||||
|
self._set_num_buckets(sequence_length)
|
||||||
|
|
||||||
# use cached buckets for backprop only
|
# use cached buckets for backprop only
|
||||||
if buckets is None:
|
if buckets is None:
|
||||||
# hash query key vectors into buckets
|
# hash query key vectors into buckets
|
||||||
buckets = self._hash_vectors(query_key_vectors, num_hashes)
|
buckets = self._hash_vectors(query_key_vectors, num_hashes)
|
||||||
|
|
||||||
assert (
|
|
||||||
int(buckets.shape[-1]) == num_hashes * sequence_length
|
|
||||||
), "last dim of buckets is {}, but should be {}".format(buckets.shape[-1], num_hashes * sequence_length)
|
|
||||||
|
|
||||||
sorted_bucket_idx, undo_sorted_bucket_idx = self._get_sorted_bucket_idx_and_undo_sorted_bucket_idx(
|
|
||||||
sequence_length, buckets, num_hashes
|
|
||||||
)
|
|
||||||
|
|
||||||
# make sure bucket idx is not longer then sequence length
|
|
||||||
sorted_bucket_idx = sorted_bucket_idx % sequence_length
|
|
||||||
|
|
||||||
# cluster query key value vectors according to hashed buckets
|
|
||||||
query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx, num_hashes)
|
|
||||||
value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx, num_hashes)
|
|
||||||
|
|
||||||
query_key_vectors = self._split_seq_length_dim_to(
|
|
||||||
query_key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
|
|
||||||
)
|
|
||||||
value_vectors = self._split_seq_length_dim_to(
|
|
||||||
value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.chunk_length is None:
|
|
||||||
assert (
|
assert (
|
||||||
self.num_chunks_before == 0 and self.num_chunks_after == 0
|
int(buckets.shape[-1]) == num_hashes * sequence_length
|
||||||
), "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0."
|
), "last dim of buckets is {}, but should be {}".format(buckets.shape[-1], num_hashes * sequence_length)
|
||||||
|
|
||||||
|
sorted_bucket_idx, undo_sorted_bucket_idx = self._get_sorted_bucket_idx_and_undo_sorted_bucket_idx(
|
||||||
|
sequence_length, buckets, num_hashes
|
||||||
|
)
|
||||||
|
|
||||||
|
# make sure bucket idx is not longer then sequence length
|
||||||
|
sorted_bucket_idx = sorted_bucket_idx % sequence_length
|
||||||
|
|
||||||
|
# cluster query key value vectors according to hashed buckets
|
||||||
|
query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx, num_hashes)
|
||||||
|
value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx, num_hashes)
|
||||||
|
|
||||||
|
query_key_vectors = self._split_seq_length_dim_to(
|
||||||
|
query_key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
|
||||||
|
)
|
||||||
|
value_vectors = self._split_seq_length_dim_to(
|
||||||
|
value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.chunk_length is None:
|
||||||
|
assert (
|
||||||
|
self.num_chunks_before == 0 and self.num_chunks_after == 0
|
||||||
|
), "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0."
|
||||||
|
else:
|
||||||
|
# get sequence length indices
|
||||||
|
sorted_bucket_idx = torch.arange(sequence_length, device=query_key_vectors.device).repeat(
|
||||||
|
batch_size, self.num_attention_heads, 1
|
||||||
|
)
|
||||||
|
|
||||||
# scale key vectors
|
# scale key vectors
|
||||||
key_vectors = self._len_and_dim_norm(query_key_vectors)
|
key_vectors = self._len_and_dim_norm(query_key_vectors)
|
||||||
@@ -398,31 +405,35 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
|||||||
sorted_bucket_idx=sorted_bucket_idx,
|
sorted_bucket_idx=sorted_bucket_idx,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
|
sequence_length=sequence_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
# free memory
|
# free memory
|
||||||
del query_key_vectors, key_vectors, value_vectors
|
del query_key_vectors, key_vectors, value_vectors
|
||||||
|
|
||||||
# sort clusters back to correct ordering
|
# re-order out_vectors and logits
|
||||||
out_vectors, logits = ReverseSort.apply(
|
if self.chunk_length < sequence_length:
|
||||||
out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, self.num_hashes
|
# sort clusters back to correct ordering
|
||||||
)
|
out_vectors, logits = ReverseSort.apply(
|
||||||
|
out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, self.num_hashes
|
||||||
# sum up all hash rounds
|
|
||||||
if num_hashes > 1:
|
|
||||||
out_vectors = self._split_seq_length_dim_to(
|
|
||||||
out_vectors, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size,
|
|
||||||
)
|
)
|
||||||
logits = self._split_seq_length_dim_to(
|
|
||||||
logits, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size,
|
|
||||||
).unsqueeze(-1)
|
|
||||||
|
|
||||||
probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True))
|
# sum up all hash rounds
|
||||||
out_vectors = torch.sum(out_vectors * probs_vectors, dim=2)
|
if num_hashes > 1:
|
||||||
|
out_vectors = self._split_seq_length_dim_to(
|
||||||
|
out_vectors, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size,
|
||||||
|
)
|
||||||
|
logits = self._split_seq_length_dim_to(
|
||||||
|
logits, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size,
|
||||||
|
).unsqueeze(-1)
|
||||||
|
|
||||||
|
probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True))
|
||||||
|
out_vectors = torch.sum(out_vectors * probs_vectors, dim=2)
|
||||||
|
# free memory
|
||||||
|
del probs_vectors
|
||||||
|
|
||||||
# free memory
|
# free memory
|
||||||
del probs_vectors
|
del logits
|
||||||
|
|
||||||
# free memory
|
|
||||||
del logits
|
|
||||||
|
|
||||||
assert out_vectors.shape == (
|
assert out_vectors.shape == (
|
||||||
batch_size,
|
batch_size,
|
||||||
@@ -554,10 +565,13 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
|||||||
self.num_buckets = num_buckets
|
self.num_buckets = num_buckets
|
||||||
|
|
||||||
def _attend(
|
def _attend(
|
||||||
self, query_vectors, key_vectors, value_vectors, sorted_bucket_idx, attention_mask, head_mask,
|
self, query_vectors, key_vectors, value_vectors, sorted_bucket_idx, attention_mask, head_mask, sequence_length
|
||||||
):
|
):
|
||||||
key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after)
|
|
||||||
value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after)
|
# look at previous and following chunks if chunked attention
|
||||||
|
if self.chunk_length < sequence_length:
|
||||||
|
key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after)
|
||||||
|
value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after)
|
||||||
|
|
||||||
# get logits and dots
|
# get logits and dots
|
||||||
query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2))
|
query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2))
|
||||||
@@ -565,10 +579,14 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
|||||||
# free memory
|
# free memory
|
||||||
del query_vectors, key_vectors
|
del query_vectors, key_vectors
|
||||||
|
|
||||||
query_bucket_idx = self._split_seq_length_dim_to(
|
# if chunked attention split bucket idxs to query and key
|
||||||
sorted_bucket_idx, -1, self.chunk_length, self.num_attention_heads
|
if self.chunk_length < sequence_length:
|
||||||
)
|
query_bucket_idx = self._split_seq_length_dim_to(
|
||||||
key_value_bucket_idx = self._look_adjacent(query_bucket_idx, self.num_chunks_before, self.num_chunks_after)
|
sorted_bucket_idx, -1, self.chunk_length, self.num_attention_heads
|
||||||
|
)
|
||||||
|
key_value_bucket_idx = self._look_adjacent(query_bucket_idx, self.num_chunks_before, self.num_chunks_after)
|
||||||
|
else:
|
||||||
|
query_bucket_idx = key_value_bucket_idx = sorted_bucket_idx
|
||||||
|
|
||||||
# get correct mask values depending on precision
|
# get correct mask values depending on precision
|
||||||
if query_key_dots.dtype == torch.float16:
|
if query_key_dots.dtype == torch.float16:
|
||||||
@@ -578,7 +596,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
|||||||
self_mask_value = self.self_mask_value_float32
|
self_mask_value = self.self_mask_value_float32
|
||||||
mask_value = self.mask_value_float32
|
mask_value = self.mask_value_float32
|
||||||
|
|
||||||
mask = self._compute_attn_mask(query_bucket_idx, key_value_bucket_idx, attention_mask)
|
mask = self._compute_attn_mask(query_bucket_idx, key_value_bucket_idx, attention_mask, sequence_length)
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
query_key_dots = torch.where(mask, query_key_dots, mask_value)
|
query_key_dots = torch.where(mask, query_key_dots, mask_value)
|
||||||
@@ -627,12 +645,13 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
|||||||
del value_vectors
|
del value_vectors
|
||||||
|
|
||||||
# merge chunk length
|
# merge chunk length
|
||||||
logits = logits.flatten(start_dim=2, end_dim=3).squeeze(-1)
|
if self.chunk_length < sequence_length:
|
||||||
out_vectors = out_vectors.flatten(start_dim=2, end_dim=3)
|
logits = logits.flatten(start_dim=2, end_dim=3).squeeze(-1)
|
||||||
|
out_vectors = out_vectors.flatten(start_dim=2, end_dim=3)
|
||||||
|
|
||||||
return out_vectors, logits, attention_probs
|
return out_vectors, logits, attention_probs
|
||||||
|
|
||||||
def _compute_attn_mask(self, query_indices, key_indices, attention_mask):
|
def _compute_attn_mask(self, query_indices, key_indices, attention_mask, sequence_length):
|
||||||
mask = None
|
mask = None
|
||||||
|
|
||||||
# Causal mask
|
# Causal mask
|
||||||
@@ -642,15 +661,27 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
|||||||
# Attention mask: chunk, look up correct mask value from key_value_bucket_idx
|
# Attention mask: chunk, look up correct mask value from key_value_bucket_idx
|
||||||
# IMPORTANT: official trax code does not use a mask for LSH Atttention. Not sure why.
|
# IMPORTANT: official trax code does not use a mask for LSH Atttention. Not sure why.
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
attention_mask = attention_mask.to(torch.uint8)[:, None, None, :]
|
# if chunked attention, the attention mask has to correspond to LSH order
|
||||||
# expand attn_mask to fit with key_value_bucket_idx shape
|
if sequence_length > self.chunk_length:
|
||||||
attention_mask = attention_mask.expand(query_indices.shape[:-1] + (-1,))
|
attention_mask = attention_mask.to(torch.uint8)[:, None, None, :]
|
||||||
key_attn_mask = torch.gather(attention_mask, -1, key_indices)
|
# expand attn_mask to fit with key_value_bucket_idx shape
|
||||||
query_attn_mask = torch.gather(attention_mask, -1, query_indices)
|
attention_mask = attention_mask.expand(query_indices.shape[:-1] + (-1,))
|
||||||
# expand to query_key_dots shape: duplicate along query axis since key sorting is the same for each query position in chunk
|
key_attn_mask = torch.gather(attention_mask, -1, key_indices)
|
||||||
attn_mask = query_attn_mask.unsqueeze(-1) * key_attn_mask.unsqueeze(-2)
|
query_attn_mask = torch.gather(attention_mask, -1, query_indices)
|
||||||
|
# expand to query_key_dots shape: duplicate along query axis since key sorting is the same for each query position in chunk
|
||||||
|
attn_mask = query_attn_mask.unsqueeze(-1) * key_attn_mask.unsqueeze(-2)
|
||||||
|
|
||||||
|
# free memory
|
||||||
|
del query_attn_mask, key_attn_mask
|
||||||
|
else:
|
||||||
|
# usual attention mask creation
|
||||||
|
attention_mask = attention_mask.to(torch.uint8)[:, None, :]
|
||||||
|
attn_mask = (attention_mask.unsqueeze(-1) * attention_mask.unsqueeze(-2)).expand(
|
||||||
|
query_indices.shape + attention_mask.shape[-1:]
|
||||||
|
)
|
||||||
|
|
||||||
# free memory
|
# free memory
|
||||||
del query_attn_mask, key_attn_mask, attention_mask
|
del attention_mask
|
||||||
|
|
||||||
# multiply by casaul mask if necessary
|
# multiply by casaul mask if necessary
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
@@ -810,36 +841,45 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
|
|||||||
torch.tensor(self.attention_head_size, device=key_vectors.device, dtype=key_vectors.dtype)
|
torch.tensor(self.attention_head_size, device=key_vectors.device, dtype=key_vectors.dtype)
|
||||||
)
|
)
|
||||||
|
|
||||||
# chunk vectors
|
# get sequence length indices
|
||||||
# B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size
|
|
||||||
query_vectors = self._split_seq_length_dim_to(
|
|
||||||
query_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
|
|
||||||
)
|
|
||||||
key_vectors = self._split_seq_length_dim_to(
|
|
||||||
key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
|
|
||||||
)
|
|
||||||
value_vectors = self._split_seq_length_dim_to(
|
|
||||||
value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
# chunk indices
|
|
||||||
indices = torch.arange(sequence_length, device=query_vectors.device).repeat(
|
indices = torch.arange(sequence_length, device=query_vectors.device).repeat(
|
||||||
batch_size, self.num_attention_heads, 1
|
batch_size, self.num_attention_heads, 1
|
||||||
)
|
)
|
||||||
query_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads)
|
|
||||||
key_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads)
|
|
||||||
|
|
||||||
# append chunks before and after
|
# if input should be chunked
|
||||||
key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after)
|
if self.chunk_length < sequence_length:
|
||||||
value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after)
|
# chunk vectors
|
||||||
key_indices = self._look_adjacent(key_indices, self.num_chunks_before, self.num_chunks_after)
|
# B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size
|
||||||
|
query_vectors = self._split_seq_length_dim_to(
|
||||||
|
query_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
|
||||||
|
)
|
||||||
|
key_vectors = self._split_seq_length_dim_to(
|
||||||
|
key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
|
||||||
|
)
|
||||||
|
value_vectors = self._split_seq_length_dim_to(
|
||||||
|
value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
# chunk indices
|
||||||
|
query_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads)
|
||||||
|
key_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads)
|
||||||
|
|
||||||
|
# append chunks before and after
|
||||||
|
key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after)
|
||||||
|
value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after)
|
||||||
|
key_indices = self._look_adjacent(key_indices, self.num_chunks_before, self.num_chunks_after)
|
||||||
|
else:
|
||||||
|
query_indices = key_indices = indices
|
||||||
|
|
||||||
|
# query-key matmul: QK^T
|
||||||
query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2))
|
query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2))
|
||||||
|
|
||||||
# free memory
|
# free memory
|
||||||
del query_vectors, key_vectors
|
del query_vectors, key_vectors
|
||||||
|
|
||||||
mask = self._compute_attn_mask(query_indices, key_indices, attention_mask, query_key_dots.shape)
|
mask = self._compute_attn_mask(
|
||||||
|
query_indices, key_indices, attention_mask, query_key_dots.shape, sequence_length
|
||||||
|
)
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
# get mask tensor depending on half precision or not
|
# get mask tensor depending on half precision or not
|
||||||
@@ -874,7 +914,8 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
|
|||||||
del value_vectors
|
del value_vectors
|
||||||
|
|
||||||
# merge chunk length
|
# merge chunk length
|
||||||
out_vectors = out_vectors.flatten(start_dim=2, end_dim=3)
|
if self.chunk_length < sequence_length:
|
||||||
|
out_vectors = out_vectors.flatten(start_dim=2, end_dim=3)
|
||||||
|
|
||||||
assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size,)
|
assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size,)
|
||||||
|
|
||||||
@@ -885,14 +926,18 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
|
|||||||
|
|
||||||
return LocalSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs)
|
return LocalSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs)
|
||||||
|
|
||||||
def _compute_attn_mask(self, query_indices, key_indices, attention_mask, query_key_dots_shape):
|
def _compute_attn_mask(self, query_indices, key_indices, attention_mask, query_key_dots_shape, sequence_length):
|
||||||
mask = None
|
mask = None
|
||||||
|
|
||||||
# chunk attention mask and look before and after
|
# chunk attention mask and look before and after
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
attention_mask = attention_mask.to(torch.uint8)[:, None, :]
|
attention_mask = attention_mask.to(torch.uint8)[:, None, :]
|
||||||
attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1)
|
|
||||||
attention_mask_key = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after)
|
if self.chunk_length < sequence_length:
|
||||||
|
attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1)
|
||||||
|
attention_mask_key = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after)
|
||||||
|
else:
|
||||||
|
attention_mask_key = attention_mask
|
||||||
|
|
||||||
# Causal mask
|
# Causal mask
|
||||||
if self.is_decoder is True:
|
if self.is_decoder is True:
|
||||||
@@ -1564,7 +1609,9 @@ class ReformerModel(ReformerPreTrainedModel):
|
|||||||
|
|
||||||
# if needs padding
|
# if needs padding
|
||||||
least_common_mult_chunk_length = _get_least_common_mult_chunk_len(self.config)
|
least_common_mult_chunk_length = _get_least_common_mult_chunk_len(self.config)
|
||||||
must_pad_to_match_chunk_length = input_shape[-1] % least_common_mult_chunk_length != 0
|
must_pad_to_match_chunk_length = (
|
||||||
|
input_shape[-1] % least_common_mult_chunk_length != 0 and input_shape[-1] > least_common_mult_chunk_length
|
||||||
|
)
|
||||||
|
|
||||||
if must_pad_to_match_chunk_length:
|
if must_pad_to_match_chunk_length:
|
||||||
padding_length = least_common_mult_chunk_length - input_shape[-1] % least_common_mult_chunk_length
|
padding_length = least_common_mult_chunk_length - input_shape[-1] % least_common_mult_chunk_length
|
||||||
|
|||||||
@@ -2188,7 +2188,7 @@ def apply_chunking_to_forward(
|
|||||||
assert (
|
assert (
|
||||||
input_tensors[0].shape[chunk_dim] % chunk_size == 0
|
input_tensors[0].shape[chunk_dim] % chunk_size == 0
|
||||||
), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format(
|
), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format(
|
||||||
input_tensors[0][chunk_dim], chunk_size
|
input_tensors[0].shape[chunk_dim], chunk_size
|
||||||
)
|
)
|
||||||
|
|
||||||
num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
|
num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
|
||||||
|
|||||||
@@ -388,6 +388,16 @@ class ReformerModelTester:
|
|||||||
output = model.generate(input_ids, attention_mask=input_mask, do_sample=False)
|
output = model.generate(input_ids, attention_mask=input_mask, do_sample=False)
|
||||||
self.parent.assertFalse(torch.isnan(output).any().item())
|
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||||
|
|
||||||
|
def create_and_check_reformer_no_chunking(self, config, input_ids, input_mask):
|
||||||
|
# force chunk length to be bigger than input_ids
|
||||||
|
config.lsh_attn_chunk_length = 2 * input_ids.shape[-1]
|
||||||
|
config.local_attn_chunk_length = 2 * input_ids.shape[-1]
|
||||||
|
model = ReformerModelWithLMHead(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
output_logits = model(input_ids, attention_mask=input_mask)[0]
|
||||||
|
self.parent.assertTrue(output_logits.shape[1] == input_ids.shape[-1])
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
(config, input_ids, input_mask,) = config_and_inputs
|
(config, input_ids, input_mask,) = config_and_inputs
|
||||||
@@ -433,6 +443,10 @@ class ReformerTesterMixin:
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_reformer_feed_backward_chunking(*config_and_inputs)
|
self.model_tester.create_and_check_reformer_feed_backward_chunking(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_reformer_no_chunking(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_reformer_no_chunking(*config_and_inputs)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_dropout_random_seed_is_changing(self):
|
def test_dropout_random_seed_is_changing(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
@@ -772,6 +786,7 @@ class ReformerIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
def test_lsh_layer_forward(self):
|
def test_lsh_layer_forward(self):
|
||||||
config = self._get_basic_config_and_input()
|
config = self._get_basic_config_and_input()
|
||||||
|
config["lsh_num_chunks_before"] = 0
|
||||||
config["attn_layers"] = ["lsh"]
|
config["attn_layers"] = ["lsh"]
|
||||||
config["is_decoder"] = False
|
config["is_decoder"] = False
|
||||||
hidden_states = self._get_hidden_states()
|
hidden_states = self._get_hidden_states()
|
||||||
@@ -787,6 +802,7 @@ class ReformerIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
def test_lsh_layer_forward_complex(self):
|
def test_lsh_layer_forward_complex(self):
|
||||||
config = self._get_basic_config_and_input()
|
config = self._get_basic_config_and_input()
|
||||||
|
config["lsh_num_chunks_before"] = 0
|
||||||
config["attn_layers"] = ["lsh"]
|
config["attn_layers"] = ["lsh"]
|
||||||
config["num_buckets"] = [2, 4]
|
config["num_buckets"] = [2, 4]
|
||||||
attn_mask = self._get_attn_mask()
|
attn_mask = self._get_attn_mask()
|
||||||
@@ -805,6 +821,7 @@ class ReformerIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
def test_local_layer_forward(self):
|
def test_local_layer_forward(self):
|
||||||
config = self._get_basic_config_and_input()
|
config = self._get_basic_config_and_input()
|
||||||
|
config["local_num_chunks_before"] = 0
|
||||||
config["attn_layers"] = ["local"]
|
config["attn_layers"] = ["local"]
|
||||||
config["is_decoder"] = False
|
config["is_decoder"] = False
|
||||||
hidden_states = self._get_hidden_states()
|
hidden_states = self._get_hidden_states()
|
||||||
@@ -820,6 +837,7 @@ class ReformerIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
def test_local_layer_forward_complex(self):
|
def test_local_layer_forward_complex(self):
|
||||||
config = self._get_basic_config_and_input()
|
config = self._get_basic_config_and_input()
|
||||||
|
config["local_num_chunks_before"] = 0
|
||||||
config["attn_layers"] = ["local"]
|
config["attn_layers"] = ["local"]
|
||||||
attn_mask = self._get_attn_mask()
|
attn_mask = self._get_attn_mask()
|
||||||
hidden_states = self._get_hidden_states()
|
hidden_states = self._get_hidden_states()
|
||||||
@@ -829,7 +847,7 @@ class ReformerIntegrationTests(unittest.TestCase):
|
|||||||
reformer_output = layer(prev_attn_output=hidden_states, hidden_states=hidden_states, attention_mask=attn_mask,)
|
reformer_output = layer(prev_attn_output=hidden_states, hidden_states=hidden_states, attention_mask=attn_mask,)
|
||||||
output_slice = reformer_output.hidden_states[0, 0, :5]
|
output_slice = reformer_output.hidden_states[0, 0, :5]
|
||||||
expected_output_slice = torch.tensor(
|
expected_output_slice = torch.tensor(
|
||||||
[1.5476, -1.9020, -0.9902, 1.5013, -0.1950], dtype=torch.float, device=torch_device,
|
[1.4750, -2.0235, -0.9743, 1.4463, -0.1269], dtype=torch.float, device=torch_device,
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user