[Reformer] fix reformer num buckets (#4564)

* fix reformer num buckets

* fix

* adapt docs

* set num buckets in config
This commit is contained in:
Patrick von Platen
2020-05-25 22:04:45 +02:00
committed by GitHub
parent 3dea40b858
commit 3e3e552125
3 changed files with 18 additions and 9 deletions

View File

@@ -62,7 +62,7 @@ For more information, see the `original Paper <https://arxiv.org/abs/2001.04451>
Note that ``config.num_buckets`` can also be factorized into a ``list``:math:`(n_{\text{buckets}}^1, n_{\text{buckets}}^2)`. This way instead of assigning the query key embedding vectors to one of :math:`(1,\ldots, n_{\text{buckets}})` they are assigned to one of :math:`(1-1,\ldots, n_{\text{buckets}}^1-1, \ldots, 1-n_{\text{buckets}}^2, \ldots, n_{\text{buckets}}^1-n_{\text{buckets}}^2)`. This is crucial for very long sequences to save memory.
It is recommended to leave ``config.num_buckets=None``, so that depending on the sequence length, a good value for ``num_buckets`` are calculated on the fly.
When training a model from scratch, it is recommended to leave ``config.num_buckets=None``, so that depending on the sequence length a good value for ``num_buckets`` is calculated on the fly. This value will then automatically be saved in the config and should be reused for inference.
Using LSH self attention, the memory and time complexity of the query-key matmul operation can be reduced from :math:`\mathcal{O}(n_s \times n_s)` to :math:`\mathcal{O}(n_s \times \log(n_s))`, which usually represents the memory and time bottleneck in a transformer model, with :math:`n_s` being the sequence length.