Support various BERT relative position embeddings (2nd) (#8276)

* Support BERT relative position embeddings

* Fix typo in README.md

* Address review comment

* Fix failing tests

* [tiny] Fix style_doc.py check by adding an empty line to configuration_bert.py

* make fix copies

* fix configs of electra and albert and fix longformer

* remove copy statement from longformer

* fix albert

* fix electra

* Add bert variants forward tests for various position embeddings

* [tiny] Fix style for test_modeling_bert.py

* improve docstring

* [tiny] improve docstring and remove unnecessary dependency

* [tiny] Remove unused import

* re-add to ALBERT

* make embeddings work for ALBERT

* add test for albert

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
zhiheng-huang
2020-11-24 05:40:53 -08:00
committed by GitHub
parent 9e71aa2f8f
commit 2c83b3c38d
16 changed files with 327 additions and 33 deletions

View File

@@ -17,7 +17,7 @@
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
@@ -295,6 +295,12 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_various_embeddings(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for type in ["absolute", "relative_key", "relative_key_query"]:
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_as_decoder(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
@@ -395,8 +401,6 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
self.assertTrue(torch.all(torch.eq(position_ids, expected_positions)))
@require_sentencepiece
@require_tokenizers
@require_torch
class RobertaModelIntegrationTest(unittest.TestCase):
@slow