Fix and improve REALM fine-tuning (#15297)

* Draft

* Add test

* Update src/transformers/models/realm/modeling_realm.py

* Apply suggestion

* Add block_mask

* Update

* Update

* Add block_embedding_to

* Remove no_grad

* Use AutoTokenizer

* Remove model.to overridding
This commit is contained in:
Li-Huai (Allan) Lin
2022-03-03 21:10:15 +08:00
committed by GitHub
parent 439de3f7f9
commit 7b3bd1f21a
6 changed files with 75 additions and 39 deletions

View File

@@ -345,7 +345,7 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase):
self.model_tester.create_and_check_embedder(*config_and_inputs)
self.model_tester.create_and_check_encoder(*config_and_inputs)
def test_retriever(self):
def test_scorer(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_scorer(*config_and_inputs)
@@ -408,6 +408,13 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase):
loss = model(**inputs).reader_output.loss
loss.backward()
# Test model.block_embedding_to
device = torch.device("cpu")
model.block_embedding_to(device)
loss = model(**inputs).reader_output.loss
loss.backward()
self.assertEqual(model.block_emb.device.type, device.type)
@slow
def test_embedder_from_pretrained(self):
model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")
@@ -506,10 +513,15 @@ class RealmModelIntegrationTest(unittest.TestCase):
concat_input_ids = torch.arange(10).view((2, 5))
concat_token_type_ids = torch.tensor([[0, 0, 1, 1, 1], [0, 0, 1, 1, 1]], dtype=torch.int64)
concat_block_mask = torch.tensor([[0, 0, 1, 1, 0], [0, 0, 1, 1, 0]], dtype=torch.int64)
relevance_score = torch.tensor([0.3, 0.7], dtype=torch.float32)
output = model(
concat_input_ids, token_type_ids=concat_token_type_ids, relevance_score=relevance_score, return_dict=True
concat_input_ids,
token_type_ids=concat_token_type_ids,
relevance_score=relevance_score,
block_mask=concat_block_mask,
return_dict=True,
)
block_idx_expected_shape = torch.Size(())

View File

@@ -98,6 +98,7 @@ class RealmRetrieverTest(TestCase):
b"This is the third record",
b"This is the fourth record",
b"This is the fifth record",
b"This is a longer longer longer record",
],
dtype=np.object,
)
@@ -135,6 +136,7 @@ class RealmRetrieverTest(TestCase):
self.assertEqual(concat_inputs.input_ids.shape, (2, 10))
self.assertEqual(concat_inputs.attention_mask.shape, (2, 10))
self.assertEqual(concat_inputs.token_type_ids.shape, (2, 10))
self.assertEqual(concat_inputs.special_tokens_mask.shape, (2, 10))
self.assertEqual(
tokenizer.convert_ids_to_tokens(concat_inputs.input_ids[0]),
["[CLS]", "test", "question", "[SEP]", "this", "is", "the", "first", "record", "[SEP]"],
@@ -149,10 +151,10 @@ class RealmRetrieverTest(TestCase):
retriever = self.get_dummy_retriever()
tokenizer = retriever.tokenizer
retrieved_block_ids = np.array([0, 3], dtype=np.long)
retrieved_block_ids = np.array([0, 3, 5], dtype=np.long)
question_input_ids = tokenizer(["Test question"]).input_ids
answer_ids = tokenizer(
["the fourth"],
["the fourth", "longer longer"],
add_special_tokens=False,
return_token_type_ids=False,
return_attention_mask=False,
@@ -163,9 +165,9 @@ class RealmRetrieverTest(TestCase):
retrieved_block_ids, question_input_ids, answer_ids=answer_ids, max_length=max_length, return_tensors="np"
)
self.assertEqual([False, True], has_answers)
self.assertEqual([[-1], [6]], start_pos)
self.assertEqual([[-1], [7]], end_pos)
self.assertEqual([False, True, True], has_answers)
self.assertEqual([[-1, -1, -1], [6, -1, -1], [6, 7, 8]], start_pos)
self.assertEqual([[-1, -1, -1], [7, -1, -1], [7, 8, 9]], end_pos)
def test_save_load_pretrained(self):
retriever = self.get_dummy_retriever()