Fix and improve REALM fine-tuning (#15297)
* Draft * Add test * Update src/transformers/models/realm/modeling_realm.py * Apply suggestion * Add block_mask * Update * Update * Add block_embedding_to * Remove no_grad * Use AutoTokenizer * Remove model.to overridding
This commit is contained in:
committed by
GitHub
parent
439de3f7f9
commit
7b3bd1f21a
@@ -81,4 +81,5 @@ This model was contributed by [qqaatw](https://huggingface.co/qqaatw). The origi
|
||||
## RealmForOpenQA
|
||||
|
||||
[[autodoc]] RealmForOpenQA
|
||||
- block_embedding_to
|
||||
- forward
|
||||
@@ -48,6 +48,7 @@ else:
|
||||
TOKENIZER_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
|
||||
("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
|
||||
|
||||
@@ -836,13 +836,13 @@ class RealmReaderProjection(nn.Module):
|
||||
self.layer_normalization = nn.LayerNorm(config.span_hidden_size, eps=config.reader_layer_norm_eps)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, hidden_states, token_type_ids):
|
||||
def forward(self, hidden_states, block_mask):
|
||||
def span_candidates(masks):
|
||||
"""
|
||||
Generate span candidates.
|
||||
|
||||
Args:
|
||||
masks: <int32> [num_retrievals, max_sequence_len]
|
||||
masks: <bool> [num_retrievals, max_sequence_len]
|
||||
|
||||
Returns:
|
||||
starts: <int32> [num_spans] ends: <int32> [num_spans] span_masks: <int32> [num_retrievals, num_spans]
|
||||
@@ -875,8 +875,7 @@ class RealmReaderProjection(nn.Module):
|
||||
hidden_states = self.dense_intermediate(hidden_states)
|
||||
# [reader_beam_size, max_sequence_len, span_hidden_size]
|
||||
start_projection, end_projection = hidden_states.chunk(2, dim=-1)
|
||||
block_mask = token_type_ids.detach().clone()
|
||||
block_mask[:, -1] = 0
|
||||
|
||||
candidate_starts, candidate_ends, candidate_mask = span_candidates(block_mask)
|
||||
|
||||
candidate_start_projections = torch.index_select(start_projection, dim=1, index=candidate_starts)
|
||||
@@ -1543,6 +1542,7 @@ class RealmReader(RealmPreTrainedModel):
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
relevance_score=None,
|
||||
block_mask=None,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
has_answers=None,
|
||||
@@ -1552,12 +1552,15 @@ class RealmReader(RealmPreTrainedModel):
|
||||
):
|
||||
r"""
|
||||
relevance_score (`torch.FloatTensor` of shape `(searcher_beam_size,)`, *optional*):
|
||||
Relevance score, which must be specified if you want to compute the marginal log loss.
|
||||
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Relevance score, which must be specified if you want to compute the logits and marginal log loss.
|
||||
block_mask (`torch.BoolTensor` of shape `(searcher_beam_size, sequence_length)`, *optional*):
|
||||
The mask of the evidence block, which must be specified if you want to compute the logits and marginal log
|
||||
loss.
|
||||
start_positions (`torch.LongTensor` of shape `(searcher_beam_size,)`, *optional*):
|
||||
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
||||
are not taken into account for computing the loss.
|
||||
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
end_positions (`torch.LongTensor` of shape `(searcher_beam_size,)`, *optional*):
|
||||
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
||||
are not taken into account for computing the loss.
|
||||
@@ -1570,8 +1573,8 @@ class RealmReader(RealmPreTrainedModel):
|
||||
|
||||
if relevance_score is None:
|
||||
raise ValueError("You have to specify `relevance_score` to calculate logits and loss.")
|
||||
if token_type_ids is None:
|
||||
raise ValueError("You have to specify `token_type_ids` to separate question block and evidence block.")
|
||||
if block_mask is None:
|
||||
raise ValueError("You have to specify `block_mask` to separate question block and evidence block.")
|
||||
if token_type_ids.size(1) < self.config.max_span_width:
|
||||
raise ValueError("The input sequence length must be greater than or equal to config.max_span_width.")
|
||||
outputs = self.realm(
|
||||
@@ -1590,7 +1593,9 @@ class RealmReader(RealmPreTrainedModel):
|
||||
sequence_output = outputs[0]
|
||||
|
||||
# [reader_beam_size, num_candidates], [num_candidates], [num_candidates]
|
||||
reader_logits, candidate_starts, candidate_ends = self.qa_outputs(sequence_output, token_type_ids)
|
||||
reader_logits, candidate_starts, candidate_ends = self.qa_outputs(
|
||||
sequence_output, block_mask[0 : self.config.reader_beam_size]
|
||||
)
|
||||
# [searcher_beam_size, 1]
|
||||
retriever_logits = torch.unsqueeze(relevance_score[0 : self.config.reader_beam_size], -1)
|
||||
# [reader_beam_size, num_candidates]
|
||||
@@ -1737,11 +1742,21 @@ class RealmForOpenQA(RealmPreTrainedModel):
|
||||
self.post_init()
|
||||
|
||||
@property
|
||||
def beam_size(self):
|
||||
def searcher_beam_size(self):
|
||||
if self.training:
|
||||
return self.config.searcher_beam_size
|
||||
return self.config.reader_beam_size
|
||||
|
||||
def block_embedding_to(self, device):
|
||||
"""Send `self.block_emb` to a specific device.
|
||||
|
||||
Args:
|
||||
device (`str` or `torch.device`):
|
||||
The device to which `self.block_emb` will be sent.
|
||||
"""
|
||||
|
||||
self.block_emb = self.block_emb.to(device)
|
||||
|
||||
@add_start_docstrings_to_model_forward(REALM_FOR_OPEN_QA_DOCSTRING.format("1, sequence_length"))
|
||||
@replace_return_docstrings(output_type=RealmForOpenQAOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
@@ -1787,36 +1802,37 @@ class RealmForOpenQA(RealmPreTrainedModel):
|
||||
question_outputs = self.embedder(
|
||||
input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, return_dict=True
|
||||
)
|
||||
|
||||
# [1, projection_size]
|
||||
question_projection = question_outputs[0]
|
||||
|
||||
# CPU computation starts.
|
||||
# [1, block_emb_size]
|
||||
batch_scores = torch.einsum("BD,QD->QB", self.block_emb, question_projection)
|
||||
batch_scores = torch.einsum("BD,QD->QB", self.block_emb, question_projection.to(self.block_emb.device))
|
||||
# [1, searcher_beam_size]
|
||||
_, retrieved_block_ids = torch.topk(batch_scores, k=self.beam_size, dim=-1)
|
||||
_, retrieved_block_ids = torch.topk(batch_scores, k=self.searcher_beam_size, dim=-1)
|
||||
# [searcher_beam_size]
|
||||
# Must convert to cpu tensor for subsequent numpy operations
|
||||
retrieved_block_ids = retrieved_block_ids.squeeze().cpu()
|
||||
retrieved_block_ids = retrieved_block_ids.squeeze()
|
||||
# [searcher_beam_size, projection_size]
|
||||
retrieved_block_emb = torch.index_select(self.block_emb, dim=0, index=retrieved_block_ids)
|
||||
# CPU computation ends.
|
||||
|
||||
# Retrieve possible answers
|
||||
has_answers, start_pos, end_pos, concat_inputs = self.retriever(
|
||||
retrieved_block_ids, input_ids, answer_ids, max_length=self.config.reader_seq_len
|
||||
retrieved_block_ids.cpu(), input_ids, answer_ids, max_length=self.config.reader_seq_len
|
||||
)
|
||||
|
||||
concat_inputs = concat_inputs.to(self.reader.device)
|
||||
block_mask = concat_inputs.special_tokens_mask.type(torch.bool).to(device=self.reader.device)
|
||||
block_mask.logical_not_().logical_and_(concat_inputs.token_type_ids.type(torch.bool))
|
||||
|
||||
if has_answers is not None:
|
||||
has_answers = torch.tensor(has_answers, dtype=torch.bool, device=self.reader.device)
|
||||
start_pos = torch.tensor(start_pos, dtype=torch.long, device=self.reader.device)
|
||||
end_pos = torch.tensor(end_pos, dtype=torch.long, device=self.reader.device)
|
||||
|
||||
concat_inputs = concat_inputs.to(self.reader.device)
|
||||
|
||||
# [searcher_beam_size, projection_size]
|
||||
retrieved_block_emb = torch.index_select(
|
||||
self.block_emb, dim=0, index=retrieved_block_ids.to(self.block_emb.device)
|
||||
)
|
||||
# [searcher_beam_size]
|
||||
retrieved_logits = torch.einsum(
|
||||
"D,BD->B", question_projection.squeeze(), retrieved_block_emb.to(question_projection.device)
|
||||
"D,BD->B", question_projection.squeeze(), retrieved_block_emb.to(self.reader.device)
|
||||
)
|
||||
|
||||
reader_output = self.reader(
|
||||
@@ -1824,6 +1840,7 @@ class RealmForOpenQA(RealmPreTrainedModel):
|
||||
attention_mask=concat_inputs.attention_mask[0 : self.config.reader_beam_size],
|
||||
token_type_ids=concat_inputs.token_type_ids[0 : self.config.reader_beam_size],
|
||||
relevance_score=retrieved_logits,
|
||||
block_mask=block_mask,
|
||||
has_answers=has_answers,
|
||||
start_positions=start_pos,
|
||||
end_positions=end_pos,
|
||||
|
||||
@@ -20,9 +20,9 @@ from typing import Optional, Union
|
||||
import numpy as np
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from ...utils import logging
|
||||
from .tokenization_realm import RealmTokenizer
|
||||
|
||||
|
||||
_REALM_BLOCK_RECORDS_FILENAME = "block_records.npy"
|
||||
@@ -97,7 +97,9 @@ class RealmRetriever:
|
||||
text.append(question)
|
||||
text_pair.append(retrieved_block.decode())
|
||||
|
||||
concat_inputs = self.tokenizer(text, text_pair, padding=True, truncation=True, max_length=max_length)
|
||||
concat_inputs = self.tokenizer(
|
||||
text, text_pair, padding=True, truncation=True, return_special_tokens_mask=True, max_length=max_length
|
||||
)
|
||||
concat_inputs_tensors = concat_inputs.convert_to_tensors(return_tensors)
|
||||
|
||||
if answer_ids is not None:
|
||||
@@ -115,7 +117,7 @@ class RealmRetriever:
|
||||
)
|
||||
block_records = np.load(block_records_path, allow_pickle=True)
|
||||
|
||||
tokenizer = RealmTokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)
|
||||
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)
|
||||
|
||||
return cls(block_records, tokenizer)
|
||||
|
||||
@@ -133,13 +135,15 @@ class RealmRetriever:
|
||||
max_answers = 0
|
||||
|
||||
for input_id in concat_inputs.input_ids:
|
||||
input_id_list = input_id.tolist()
|
||||
# Check answers between two [SEP] tokens
|
||||
first_sep_idx = input_id_list.index(self.tokenizer.sep_token_id)
|
||||
second_sep_idx = first_sep_idx + 1 + input_id_list[first_sep_idx + 1 :].index(self.tokenizer.sep_token_id)
|
||||
|
||||
start_pos.append([])
|
||||
end_pos.append([])
|
||||
input_id_list = input_id.tolist()
|
||||
# Checking answers after the [SEP] token
|
||||
sep_idx = input_id_list.index(self.tokenizer.sep_token_id)
|
||||
for answer in answer_ids:
|
||||
for idx in range(sep_idx, len(input_id)):
|
||||
for idx in range(first_sep_idx + 1, second_sep_idx):
|
||||
if answer[0] == input_id_list[idx]:
|
||||
if input_id_list[idx : idx + len(answer)] == answer:
|
||||
start_pos[-1].append(idx)
|
||||
@@ -158,5 +162,4 @@ class RealmRetriever:
|
||||
padded = [-1] * (max_answers - len(start_pos_))
|
||||
start_pos_ += padded
|
||||
end_pos_ += padded
|
||||
|
||||
return has_answers, start_pos, end_pos
|
||||
|
||||
@@ -345,7 +345,7 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
self.model_tester.create_and_check_embedder(*config_and_inputs)
|
||||
self.model_tester.create_and_check_encoder(*config_and_inputs)
|
||||
|
||||
def test_retriever(self):
|
||||
def test_scorer(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_scorer(*config_and_inputs)
|
||||
|
||||
@@ -408,6 +408,13 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
loss = model(**inputs).reader_output.loss
|
||||
loss.backward()
|
||||
|
||||
# Test model.block_embedding_to
|
||||
device = torch.device("cpu")
|
||||
model.block_embedding_to(device)
|
||||
loss = model(**inputs).reader_output.loss
|
||||
loss.backward()
|
||||
self.assertEqual(model.block_emb.device.type, device.type)
|
||||
|
||||
@slow
|
||||
def test_embedder_from_pretrained(self):
|
||||
model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")
|
||||
@@ -506,10 +513,15 @@ class RealmModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
concat_input_ids = torch.arange(10).view((2, 5))
|
||||
concat_token_type_ids = torch.tensor([[0, 0, 1, 1, 1], [0, 0, 1, 1, 1]], dtype=torch.int64)
|
||||
concat_block_mask = torch.tensor([[0, 0, 1, 1, 0], [0, 0, 1, 1, 0]], dtype=torch.int64)
|
||||
relevance_score = torch.tensor([0.3, 0.7], dtype=torch.float32)
|
||||
|
||||
output = model(
|
||||
concat_input_ids, token_type_ids=concat_token_type_ids, relevance_score=relevance_score, return_dict=True
|
||||
concat_input_ids,
|
||||
token_type_ids=concat_token_type_ids,
|
||||
relevance_score=relevance_score,
|
||||
block_mask=concat_block_mask,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
block_idx_expected_shape = torch.Size(())
|
||||
|
||||
@@ -98,6 +98,7 @@ class RealmRetrieverTest(TestCase):
|
||||
b"This is the third record",
|
||||
b"This is the fourth record",
|
||||
b"This is the fifth record",
|
||||
b"This is a longer longer longer record",
|
||||
],
|
||||
dtype=np.object,
|
||||
)
|
||||
@@ -135,6 +136,7 @@ class RealmRetrieverTest(TestCase):
|
||||
self.assertEqual(concat_inputs.input_ids.shape, (2, 10))
|
||||
self.assertEqual(concat_inputs.attention_mask.shape, (2, 10))
|
||||
self.assertEqual(concat_inputs.token_type_ids.shape, (2, 10))
|
||||
self.assertEqual(concat_inputs.special_tokens_mask.shape, (2, 10))
|
||||
self.assertEqual(
|
||||
tokenizer.convert_ids_to_tokens(concat_inputs.input_ids[0]),
|
||||
["[CLS]", "test", "question", "[SEP]", "this", "is", "the", "first", "record", "[SEP]"],
|
||||
@@ -149,10 +151,10 @@ class RealmRetrieverTest(TestCase):
|
||||
retriever = self.get_dummy_retriever()
|
||||
tokenizer = retriever.tokenizer
|
||||
|
||||
retrieved_block_ids = np.array([0, 3], dtype=np.long)
|
||||
retrieved_block_ids = np.array([0, 3, 5], dtype=np.long)
|
||||
question_input_ids = tokenizer(["Test question"]).input_ids
|
||||
answer_ids = tokenizer(
|
||||
["the fourth"],
|
||||
["the fourth", "longer longer"],
|
||||
add_special_tokens=False,
|
||||
return_token_type_ids=False,
|
||||
return_attention_mask=False,
|
||||
@@ -163,9 +165,9 @@ class RealmRetrieverTest(TestCase):
|
||||
retrieved_block_ids, question_input_ids, answer_ids=answer_ids, max_length=max_length, return_tensors="np"
|
||||
)
|
||||
|
||||
self.assertEqual([False, True], has_answers)
|
||||
self.assertEqual([[-1], [6]], start_pos)
|
||||
self.assertEqual([[-1], [7]], end_pos)
|
||||
self.assertEqual([False, True, True], has_answers)
|
||||
self.assertEqual([[-1, -1, -1], [6, -1, -1], [6, 7, 8]], start_pos)
|
||||
self.assertEqual([[-1, -1, -1], [7, -1, -1], [7, 8, 9]], end_pos)
|
||||
|
||||
def test_save_load_pretrained(self):
|
||||
retriever = self.get_dummy_retriever()
|
||||
|
||||
Reference in New Issue
Block a user