Fix and improve REALM fine-tuning (#15297)
* Draft * Add test * Update src/transformers/models/realm/modeling_realm.py * Apply suggestion * Add block_mask * Update * Update * Add block_embedding_to * Remove no_grad * Use AutoTokenizer * Remove model.to overriding
This commit is contained in:
committed by
GitHub
parent
439de3f7f9
commit
7b3bd1f21a
@@ -345,7 +345,7 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
self.model_tester.create_and_check_embedder(*config_and_inputs)
|
||||
self.model_tester.create_and_check_encoder(*config_and_inputs)
|
||||
|
||||
def test_retriever(self):
|
||||
def test_scorer(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_scorer(*config_and_inputs)
|
||||
|
||||
@@ -408,6 +408,13 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
loss = model(**inputs).reader_output.loss
|
||||
loss.backward()
|
||||
|
||||
# Test model.block_embedding_to
|
||||
device = torch.device("cpu")
|
||||
model.block_embedding_to(device)
|
||||
loss = model(**inputs).reader_output.loss
|
||||
loss.backward()
|
||||
self.assertEqual(model.block_emb.device.type, device.type)
|
||||
|
||||
@slow
|
||||
def test_embedder_from_pretrained(self):
|
||||
model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")
|
||||
@@ -506,10 +513,15 @@ class RealmModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
concat_input_ids = torch.arange(10).view((2, 5))
|
||||
concat_token_type_ids = torch.tensor([[0, 0, 1, 1, 1], [0, 0, 1, 1, 1]], dtype=torch.int64)
|
||||
concat_block_mask = torch.tensor([[0, 0, 1, 1, 0], [0, 0, 1, 1, 0]], dtype=torch.int64)
|
||||
relevance_score = torch.tensor([0.3, 0.7], dtype=torch.float32)
|
||||
|
||||
output = model(
|
||||
concat_input_ids, token_type_ids=concat_token_type_ids, relevance_score=relevance_score, return_dict=True
|
||||
concat_input_ids,
|
||||
token_type_ids=concat_token_type_ids,
|
||||
relevance_score=relevance_score,
|
||||
block_mask=concat_block_mask,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
block_idx_expected_shape = torch.Size(())
|
||||
|
||||
@@ -98,6 +98,7 @@ class RealmRetrieverTest(TestCase):
|
||||
b"This is the third record",
|
||||
b"This is the fourth record",
|
||||
b"This is the fifth record",
|
||||
b"This is a longer longer longer record",
|
||||
],
|
||||
dtype=np.object,
|
||||
)
|
||||
@@ -135,6 +136,7 @@ class RealmRetrieverTest(TestCase):
|
||||
self.assertEqual(concat_inputs.input_ids.shape, (2, 10))
|
||||
self.assertEqual(concat_inputs.attention_mask.shape, (2, 10))
|
||||
self.assertEqual(concat_inputs.token_type_ids.shape, (2, 10))
|
||||
self.assertEqual(concat_inputs.special_tokens_mask.shape, (2, 10))
|
||||
self.assertEqual(
|
||||
tokenizer.convert_ids_to_tokens(concat_inputs.input_ids[0]),
|
||||
["[CLS]", "test", "question", "[SEP]", "this", "is", "the", "first", "record", "[SEP]"],
|
||||
@@ -149,10 +151,10 @@ class RealmRetrieverTest(TestCase):
|
||||
retriever = self.get_dummy_retriever()
|
||||
tokenizer = retriever.tokenizer
|
||||
|
||||
retrieved_block_ids = np.array([0, 3], dtype=np.long)
|
||||
retrieved_block_ids = np.array([0, 3, 5], dtype=np.long)
|
||||
question_input_ids = tokenizer(["Test question"]).input_ids
|
||||
answer_ids = tokenizer(
|
||||
["the fourth"],
|
||||
["the fourth", "longer longer"],
|
||||
add_special_tokens=False,
|
||||
return_token_type_ids=False,
|
||||
return_attention_mask=False,
|
||||
@@ -163,9 +165,9 @@ class RealmRetrieverTest(TestCase):
|
||||
retrieved_block_ids, question_input_ids, answer_ids=answer_ids, max_length=max_length, return_tensors="np"
|
||||
)
|
||||
|
||||
self.assertEqual([False, True], has_answers)
|
||||
self.assertEqual([[-1], [6]], start_pos)
|
||||
self.assertEqual([[-1], [7]], end_pos)
|
||||
self.assertEqual([False, True, True], has_answers)
|
||||
self.assertEqual([[-1, -1, -1], [6, -1, -1], [6, 7, 8]], start_pos)
|
||||
self.assertEqual([[-1, -1, -1], [7, -1, -1], [7, 8, 9]], end_pos)
|
||||
|
||||
def test_save_load_pretrained(self):
|
||||
retriever = self.get_dummy_retriever()
|
||||
|
||||
Reference in New Issue
Block a user