Support various BERT relative position embeddings (2nd) (#8276)
* Support BERT relative position embeddings * Fix typo in README.md * Address review comment * Fix failing tests * [tiny] Fix style_doc.py check by adding an empty line to configuration_bert.py * make fix copies * fix configs of electra and albert and fix longformer * remove copy statement from longformer * fix albert * fix electra * Add bert variants forward tests for various position embeddings * [tiny] Fix style for test_modeling_bert.py * improve docstring * [tiny] improve docstring and remove unnecessary dependency * [tiny] Remove unused import * re-add to ALBERT * make embeddings work for ALBERT * add test for albert Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -404,6 +404,12 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_model_various_embeddings(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
for type in ["absolute", "relative_key", "relative_key_query"]:
|
||||
config_and_inputs[0].position_embedding_type = type
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_model_as_decoder(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
|
||||
@@ -477,3 +483,43 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = BertModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
@require_torch
|
||||
class BertModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_inference_no_head_absolute_embedding(self):
|
||||
model = BertModel.from_pretrained("bert-base-uncased")
|
||||
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
|
||||
output = model(input_ids)[0]
|
||||
expected_shape = torch.Size((1, 11, 768))
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
expected_slice = torch.tensor(
|
||||
[[[-0.0483, 0.1188, -0.0313], [-0.0606, 0.1435, 0.0199], [-0.0235, 0.1519, 0.0175]]]
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_inference_no_head_relative_embedding_key(self):
|
||||
model = BertModel.from_pretrained("zhiheng-huang/bert-base-uncased-embedding-relative-key")
|
||||
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
|
||||
output = model(input_ids)[0]
|
||||
expected_shape = torch.Size((1, 11, 768))
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
expected_slice = torch.tensor(
|
||||
[[[0.3492, 0.4126, -0.1484], [0.2274, -0.0549, 0.1623], [0.5889, 0.6797, -0.0189]]]
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_inference_no_head_relative_embedding_key_query(self):
|
||||
model = BertModel.from_pretrained("zhiheng-huang/bert-base-uncased-embedding-relative-key-query")
|
||||
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
|
||||
output = model(input_ids)[0]
|
||||
expected_shape = torch.Size((1, 11, 768))
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
expected_slice = torch.tensor([[[1.1677, 0.5129, 0.9524], [0.6659, 0.5958, 0.6688], [1.1714, 0.1764, 0.6266]]])
|
||||
|
||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
|
||||
|
||||
Reference in New Issue
Block a user