[Reformer] fix reformer num buckets (#4564)

* fix reformer num buckets

* fix

* adapt docs

* set num buckets in config
This commit is contained in:
Patrick von Platen
2020-05-25 22:04:45 +02:00
committed by GitHub
parent 3dea40b858
commit 3e3e552125
3 changed files with 18 additions and 9 deletions

View File

@@ -110,10 +110,10 @@ class ReformerConfig(PretrainedConfig):
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
num_attention_heads (:obj:`int`, optional, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
num_buckets (:obj:`int` or :obj:`list(int)`, optional, defaults to `64`):
num_buckets (:obj:`int` or :obj:`list(int)`, optional, defaults to `None`):
Number of buckets that the query-key vectors can be "hashed into" using the locality sensitive hashing scheme. Each query-key vector is hashed into a hash in `1, ..., num_buckets`.
The number of buckets can also be factorized into a list for improved memory complexity. In this case, each query key vector is hashed into a hash in `1-1, 1-2, ..., num_buckets[0]-1, ..., num_buckets[0]-num_buckets[1]` if `num_buckets` is factorized into two factors.
The number of buckets (or the product of the factors) should approximately equal sequence length / lsh_chunk_length.
The number of buckets (or the product of the factors) should approximately equal sequence length / lsh_chunk_length. If `num_buckets` is set to `None`, a good value for `num_buckets` is calculated on the fly.
num_hashes (:obj:`int`, optional, defaults to 1):
Number of hashing rounds (e.g. number of random rotations) in the Locality Sensitive Hashing scheme.
The higher `num_hashes`, the more accurate the `LSHSelfAttention` becomes, but also the more memory and time intensive the hashing becomes.
@@ -172,7 +172,7 @@ class ReformerConfig(PretrainedConfig):
lsh_num_chunks_after=0,
max_position_embeddings=4096,
num_attention_heads=2,
num_buckets=32,
num_buckets=None,
num_hashes=1,
pad_token_id=0,
vocab_size=320,

View File

@@ -283,6 +283,8 @@ class EfficientAttentionMixin:
class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
def __init__(self, config):
super().__init__()
self.config = config
self.chunk_length = config.lsh_attn_chunk_length
self.num_hashes = config.num_hashes
self.num_buckets = config.num_buckets
@@ -532,15 +534,22 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
return sorted_bucket_idx, undo_sorted_bucket_idx
def _set_num_buckets(self, sequence_length):
# recommended `num_buckets` from paper
num_buckets = 2 * sequence_length // self.chunk_length
# `num_buckets` should be set to 2 * sequence_length // chunk_length as recommended in paper
num_buckets_pow_2 = (2 * (sequence_length // self.chunk_length)).bit_length() - 1
# make sure buckets are power of 2
num_buckets = 2 ** num_buckets_pow_2
# factorize `num_buckets` if `num_buckets` becomes too large
num_buckets_limit = max(int((self.max_position_embeddings // self.chunk_length) ** (0.5)), self.chunk_length,)
if num_buckets > 2 * num_buckets_limit:
num_buckets = [num_buckets_limit, num_buckets // num_buckets_limit + 1]
num_buckets_limit = 2 * max(
int((self.max_position_embeddings // self.chunk_length) ** (0.5)), self.chunk_length,
)
if num_buckets > num_buckets_limit:
num_buckets = [2 ** (num_buckets_pow_2 // 2), 2 ** (num_buckets_pow_2 - num_buckets_pow_2 // 2)]
logger.warning("config.num_buckets is not set. Setting config.num_buckets to {}...".format(num_buckets))
# set num buckets in config to be properly saved
self.config.num_buckets = num_buckets
self.num_buckets = num_buckets
def _attend(