From 3e3e552125e86824239e445dd3c659df0aea4db9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 25 May 2020 22:04:45 +0200 Subject: [PATCH] [Reformer] fix reformer num buckets (#4564) * fix reformer num buckets * fix * adapt docs * set num buckets in config --- docs/source/model_doc/reformer.rst | 2 +- src/transformers/configuration_reformer.py | 6 +++--- src/transformers/modeling_reformer.py | 19 ++++++++++++++----- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/docs/source/model_doc/reformer.rst b/docs/source/model_doc/reformer.rst index 781e203ce7..2c3b3faf1c 100644 --- a/docs/source/model_doc/reformer.rst +++ b/docs/source/model_doc/reformer.rst @@ -62,7 +62,7 @@ For more information, see the `original Paper Note that ``config.num_buckets`` can also be factorized into a ``list``:math:`(n_{\text{buckets}}^1, n_{\text{buckets}}^2)`. This way instead of assigning the query key embedding vectors to one of :math:`(1,\ldots, n_{\text{buckets}})` they are assigned to one of :math:`(1-1,\ldots, n_{\text{buckets}}^1-1, \ldots, 1-n_{\text{buckets}}^2, \ldots, n_{\text{buckets}}^1-n_{\text{buckets}}^2)`. This is crucial for very long sequences to save memory. -It is recommended to leave ``config.num_buckets=None``, so that depending on the sequence length, a good value for ``num_buckets`` are calculated on the fly. +When training a model from scratch, it is recommended to leave ``config.num_buckets=None``, so that depending on the sequence length a good value for ``num_buckets`` is calculated on the fly. This value will then automatically be saved in the config and should be reused for inference. Using LSH self attention, the memory and time complexity of the query-key matmul operation can be reduced from :math:`\mathcal{O}(n_s \times n_s)` to :math:`\mathcal{O}(n_s \times \log(n_s))`, which usually represents the memory and time bottleneck in a transformer model, with :math:`n_s` being the sequence length. 
diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 572fa58fac..04b50ffffc 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -110,10 +110,10 @@ class ReformerConfig(PretrainedConfig): Typically set this to something large just in case (e.g., 512 or 1024 or 2048). num_attention_heads (:obj:`int`, optional, defaults to 12): Number of attention heads for each attention layer in the Transformer encoder. - num_buckets (:obj:`int` or :obj:`list(int)`, optional, defaults to `64`): + num_buckets (:obj:`int` or :obj:`list(int)`, optional, defaults to `None`): Number of buckets, the key query vectors can be "hashed into" using the locality sensitive hashing scheme. Each query key vector is hashed into a hash in `1, ..., num_buckets`. The number of buckets can also be factorized into a list for improved memory complexity. In this case, each query key vector is hashed into a hash in `1-1, 1-2, ..., num_buckets[0]-1, ..., num_buckets[0]-num_buckets[1]` if `num_buckets` is factorized into two factors. - The number of buckets (or the product the factors) should approximately equal sequence length / lsh_chunk_length. + The number of buckets (or the product of the factors) should approximately equal sequence length / lsh_chunk_length. If `num_buckets` is set to `None`, a good value for `num_buckets` is calculated on the fly. num_hashes (:obj:`int`, optional, defaults to 1): Number of hashing rounds (e.g. number of random rotations) in Local Sensitive Hashing scheme. The higher `num_hashes`, the more accurate the `LSHSelfAttention` becomes, but also the more memory and time intensive the hashing becomes. 
@@ -172,7 +172,7 @@ class ReformerConfig(PretrainedConfig): lsh_num_chunks_after=0, max_position_embeddings=4096, num_attention_heads=2, - num_buckets=32, + num_buckets=None, num_hashes=1, pad_token_id=0, vocab_size=320, diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 307d34df80..aa3385a9b9 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -283,6 +283,8 @@ class EfficientAttentionMixin: class LSHSelfAttention(nn.Module, EfficientAttentionMixin): def __init__(self, config): super().__init__() + self.config = config + self.chunk_length = config.lsh_attn_chunk_length self.num_hashes = config.num_hashes self.num_buckets = config.num_buckets @@ -532,15 +534,22 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): return sorted_bucket_idx, undo_sorted_bucket_idx def _set_num_buckets(self, sequence_length): - # recommended `num_buckets` from paper - num_buckets = 2 * sequence_length // self.chunk_length + # `num_buckets` should be set to 2 * sequence_length // chunk_length as recommended in paper + num_buckets_pow_2 = (2 * (sequence_length // self.chunk_length)).bit_length() - 1 + # make sure buckets are power of 2 + num_buckets = 2 ** num_buckets_pow_2 # factorize `num_buckets` if `num_buckets` becomes too large - num_buckets_limit = max(int((self.max_position_embeddings // self.chunk_length) ** (0.5)), self.chunk_length,) - if num_buckets > 2 * num_buckets_limit: - num_buckets = [num_buckets_limit, num_buckets // num_buckets_limit + 1] + num_buckets_limit = 2 * max( + int((self.max_position_embeddings // self.chunk_length) ** (0.5)), self.chunk_length, + ) + if num_buckets > num_buckets_limit: + num_buckets = [2 ** (num_buckets_pow_2 // 2), 2 ** (num_buckets_pow_2 - num_buckets_pow_2 // 2)] logger.warning("config.num_buckets is not set. 
Setting config.num_buckets to {}...".format(num_buckets)) + + # set num buckets in config to be properly saved + self.config.num_buckets = num_buckets self.num_buckets = num_buckets def _attend(