[Reformer] fix reformer num buckets (#4564)
* fix reformer num buckets * fix * adapt docs * set num buckets in config
This commit is contained in:
committed by
GitHub
parent
3dea40b858
commit
3e3e552125
@@ -110,10 +110,10 @@ class ReformerConfig(PretrainedConfig):
|
||||
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
|
||||
num_attention_heads (:obj:`int`, optional, defaults to 12):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_buckets (:obj:`int` or :obj:`list(int)`, optional, defaults to `64`):
|
||||
num_buckets (:obj:`int` or :obj:`list(int)`, optional, defaults to `None`):
|
||||
Number of buckets that the query-key vectors can be "hashed into" using the locality sensitive hashing scheme. Each query-key vector is hashed into a hash in `1, ..., num_buckets`.
|
||||
The number of buckets can also be factorized into a list for improved memory complexity. In this case, each query key vector is hashed into a hash in `1-1, 1-2, ..., num_buckets[0]-1, ..., num_buckets[0]-num_buckets[1]` if `num_buckets` is factorized into two factors.
|
||||
The number of buckets (or the product of the factors) should approximately equal sequence length / lsh_chunk_length.
|
||||
The number of buckets (or the product of the factors) should approximately equal sequence length / lsh_chunk_length. If `num_buckets` is set to `None`, a good value for `num_buckets` is calculated on the fly.
|
||||
num_hashes (:obj:`int`, optional, defaults to 1):
|
||||
Number of hashing rounds (e.g. number of random rotations) in the Locality Sensitive Hashing scheme.
|
||||
The higher `num_hashes`, the more accurate the `LSHSelfAttention` becomes, but also the more memory and time intensive the hashing becomes.
|
||||
@@ -172,7 +172,7 @@ class ReformerConfig(PretrainedConfig):
|
||||
lsh_num_chunks_after=0,
|
||||
max_position_embeddings=4096,
|
||||
num_attention_heads=2,
|
||||
num_buckets=32,
|
||||
num_buckets=None,
|
||||
num_hashes=1,
|
||||
pad_token_id=0,
|
||||
vocab_size=320,
|
||||
|
||||
@@ -283,6 +283,8 @@ class EfficientAttentionMixin:
|
||||
class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.chunk_length = config.lsh_attn_chunk_length
|
||||
self.num_hashes = config.num_hashes
|
||||
self.num_buckets = config.num_buckets
|
||||
@@ -532,15 +534,22 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
||||
return sorted_bucket_idx, undo_sorted_bucket_idx
|
||||
|
||||
def _set_num_buckets(self, sequence_length):
|
||||
# recommended `num_buckets` from paper
|
||||
num_buckets = 2 * sequence_length // self.chunk_length
|
||||
# `num_buckets` should be set to 2 * sequence_length // chunk_length as recommended in paper
|
||||
num_buckets_pow_2 = (2 * (sequence_length // self.chunk_length)).bit_length() - 1
|
||||
# make sure buckets are power of 2
|
||||
num_buckets = 2 ** num_buckets_pow_2
|
||||
|
||||
# factorize `num_buckets` if `num_buckets` becomes too large
|
||||
num_buckets_limit = max(int((self.max_position_embeddings // self.chunk_length) ** (0.5)), self.chunk_length,)
|
||||
if num_buckets > 2 * num_buckets_limit:
|
||||
num_buckets = [num_buckets_limit, num_buckets // num_buckets_limit + 1]
|
||||
num_buckets_limit = 2 * max(
|
||||
int((self.max_position_embeddings // self.chunk_length) ** (0.5)), self.chunk_length,
|
||||
)
|
||||
if num_buckets > num_buckets_limit:
|
||||
num_buckets = [2 ** (num_buckets_pow_2 // 2), 2 ** (num_buckets_pow_2 - num_buckets_pow_2 // 2)]
|
||||
|
||||
logger.warning("config.num_buckets is not set. Setting config.num_buckets to {}...".format(num_buckets))
|
||||
|
||||
# set num buckets in config to be properly saved
|
||||
self.config.num_buckets = num_buckets
|
||||
self.num_buckets = num_buckets
|
||||
|
||||
def _attend(
|
||||
|
||||
Reference in New Issue
Block a user