[Longformer] Better handling of global attention mask vs local attention mask (#4672)

* better api

* improve automatic setting of global attention mask

* fix longformer bug

* fix global attention mask in test

* fix global attn mask flatten

* fix slow tests

* update docstring

* update docs and make more robust

* improve attention mask
This commit is contained in:
Patrick von Platen
2020-05-29 17:58:42 +02:00
committed by GitHub
parent e2230ba77b
commit 56ee2560be
3 changed files with 114 additions and 46 deletions

View File

@@ -21,7 +21,7 @@ A selecetd few tokens attend "globally" to all other tokens, as it is convention
Note that "locally" and "globally" attending tokens are projected by different query, key and value matrices.
Also note that every "locally" attending token not only attends to tokens within its window :math:`w`, but also to all "globally" attending tokens so that global attention is *symmetric*.
The user can define which tokens are masked, which tokens attend "locally" and which tokens attend "globally" by setting the `config.attention_mask` `torch.Tensor` appropriately. In contrast to other models `Longformer` accepts the following values in `config.attention_mask`: `0` - the token is masked and not attended at all (as is done in other models), `1` - the token attends "locally", `2` - token attends "globally". For more information please also refer to :func:`~transformers.LongformerModel.forward` method.
The user can define which tokens attend "locally" and which tokens attend "globally" by setting the tensor `global_attention_mask` at run-time appropriately. `Longformer` employs the following logic for `global_attention_mask`: `0` - the token attends "locally", `1` - token attends "globally". For more information please also refer to :func:`~transformers.LongformerModel.forward` method.
Using Longformer self attention, the memory and time complexity of the query-key matmul operation, which usually represents the memory and time bottleneck, can be reduced from :math:`\mathcal{O}(n_s \times n_s)` to :math:`\mathcal{O}(n_s \times w)`, with :math:`n_s` being the sequence length and :math:`w` being the average window size. It is assumed that the number of "globally" attending tokens is insignificant as compared to the number of "locally" attending tokens.