[Reformer] fix reformer num buckets (#4564)
* fix reformer num buckets * fix * adapt docs * set num buckets in config
This commit is contained in:
committed by
GitHub
parent
3dea40b858
commit
3e3e552125
@@ -62,7 +62,7 @@ For more information, see the `original Paper <https://arxiv.org/abs/2001.04451>
|
||||
|
||||
Note that ``config.num_buckets`` can also be factorized into a ``list``:math:`(n_{\text{buckets}}^1, n_{\text{buckets}}^2)`. This way instead of assigning the query key embedding vectors to one of :math:`(1,\ldots, n_{\text{buckets}})` they are assigned to one of :math:`(1-1,\ldots, n_{\text{buckets}}^1-1, \ldots, 1-n_{\text{buckets}}^2, \ldots, n_{\text{buckets}}^1-n_{\text{buckets}}^2)`. This is crucial for very long sequences to save memory.
|
||||
|
||||
It is recommended to leave ``config.num_buckets=None``, so that depending on the sequence length, a good value for ``num_buckets`` are calculated on the fly.
|
||||
When training a model from scratch, it is recommended to leave ``config.num_buckets=None``, so that depending on the sequence length a good value for ``num_buckets`` is calculated on the fly. This value will then automatically be saved in the config and should be reused for inference.
|
||||
|
||||
Using LSH self attention, the memory and time complexity of the query-key matmul operation can be reduced from :math:`\mathcal{O}(n_s \times n_s)` to :math:`\mathcal{O}(n_s \times \log(n_s))`, which usually represents the memory and time bottleneck in a transformer model, with :math:`n_s` being the sequence length.
|
||||
|
||||
|
||||
@@ -110,10 +110,10 @@ class ReformerConfig(PretrainedConfig):
|
||||
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
|
||||
num_attention_heads (:obj:`int`, optional, defaults to 12):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_buckets (:obj:`int` or :obj:`list(int)`, optional, defaults to `64`):
|
||||
num_buckets (:obj:`int` or :obj:`list(int)`, optional, defaults to `None`):
|
||||
Number of buckets, the key query vectors can be "hashed into" using the locality sensitive hashing scheme. Each query key vector is hashed into a hash in `1, ..., num_buckets`.
|
||||
The number of buckets can also be factorized into a list for improved memory complexity. In this case, each query key vector is hashed into a hash in `1-1, 1-2, ..., num_buckets[0]-1, ..., num_buckets[0]-num_buckets[1]` if `num_buckets` is factorized into two factors.
|
||||
The number of buckets (or the product the factors) should approximately equal sequence length / lsh_chunk_length.
|
||||
The number of buckets (or the product the factors) should approximately equal sequence length / lsh_chunk_length. If `num_buckets` is set to `None`, a good value for `num_buckets` is calculated on the fly.
|
||||
num_hashes (:obj:`int`, optional, defaults to 1):
|
||||
Number of hashing rounds (e.g. number of random rotations) in Local Sensitive Hashing scheme.
|
||||
The higher `num_hashes`, the more accurate the `LSHSelfAttention` becomes, but also the more memory and time intensive the hashing becomes.
|
||||
@@ -172,7 +172,7 @@ class ReformerConfig(PretrainedConfig):
|
||||
lsh_num_chunks_after=0,
|
||||
max_position_embeddings=4096,
|
||||
num_attention_heads=2,
|
||||
num_buckets=32,
|
||||
num_buckets=None,
|
||||
num_hashes=1,
|
||||
pad_token_id=0,
|
||||
vocab_size=320,
|
||||
|
||||
@@ -283,6 +283,8 @@ class EfficientAttentionMixin:
|
||||
class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.chunk_length = config.lsh_attn_chunk_length
|
||||
self.num_hashes = config.num_hashes
|
||||
self.num_buckets = config.num_buckets
|
||||
@@ -532,15 +534,22 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
||||
return sorted_bucket_idx, undo_sorted_bucket_idx
|
||||
|
||||
def _set_num_buckets(self, sequence_length):
|
||||
# recommended `num_buckets` from paper
|
||||
num_buckets = 2 * sequence_length // self.chunk_length
|
||||
# `num_buckets` should be set to 2 * sequence_length // chunk_length as recommended in paper
|
||||
num_buckets_pow_2 = (2 * (sequence_length // self.chunk_length)).bit_length() - 1
|
||||
# make sure buckets are power of 2
|
||||
num_buckets = 2 ** num_buckets_pow_2
|
||||
|
||||
# factorize `num_buckets` if `num_buckets` becomes too large
|
||||
num_buckets_limit = max(int((self.max_position_embeddings // self.chunk_length) ** (0.5)), self.chunk_length,)
|
||||
if num_buckets > 2 * num_buckets_limit:
|
||||
num_buckets = [num_buckets_limit, num_buckets // num_buckets_limit + 1]
|
||||
num_buckets_limit = 2 * max(
|
||||
int((self.max_position_embeddings // self.chunk_length) ** (0.5)), self.chunk_length,
|
||||
)
|
||||
if num_buckets > num_buckets_limit:
|
||||
num_buckets = [2 ** (num_buckets_pow_2 // 2), 2 ** (num_buckets_pow_2 - num_buckets_pow_2 // 2)]
|
||||
|
||||
logger.warning("config.num_buckets is not set. Setting config.num_buckets to {}...".format(num_buckets))
|
||||
|
||||
# set num buckets in config to be properly saved
|
||||
self.config.num_buckets = num_buckets
|
||||
self.num_buckets = num_buckets
|
||||
|
||||
def _attend(
|
||||
|
||||
Reference in New Issue
Block a user