Support various BERT relative position embeddings (2nd) (#8276)

* Support BERT relative position embeddings

* Fix typo in README.md

* Address review comment

* Fix failing tests

* [tiny] Fix style_doc.py check by adding an empty line to configuration_bert.py

* make fix copies

* fix configs of electra and albert and fix longformer

* remove copy statement from longformer

* fix albert

* fix electra

* Add bert variants forward tests for various position embeddings

* [tiny] Fix style for test_modeling_bert.py

* improve docstring

* [tiny] improve docstring and remove unnecessary dependency

* [tiny] Remove unused import

* re-add to ALBERT

* make embeddings work for ALBERT

* add test for albert

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
zhiheng-huang
2020-11-24 05:40:53 -08:00
committed by GitHub
parent 9e71aa2f8f
commit 2c83b3c38d
16 changed files with 327 additions and 33 deletions

View File

@@ -91,6 +91,13 @@ class BertConfig(PretrainedConfig):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
:obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
:obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.)
<https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
`Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
<https://arxiv.org/abs/2009.13658>`__.
Examples::
@@ -123,6 +130,7 @@ class BertConfig(PretrainedConfig):
layer_norm_eps=1e-12,
pad_token_id=0,
gradient_checkpointing=False,
position_embedding_type="absolute",
**kwargs
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -140,3 +148,4 @@ class BertConfig(PretrainedConfig):
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing
self.position_embedding_type = position_embedding_type