Fix and improve REALM fine-tuning (#15297)
* Draft * Add test * Update src/transformers/models/realm/modeling_realm.py * Apply suggestion * Add block_mask * Update * Update * Add block_embedding_to * Remove no_grad * Use AutoTokenizer * Remove model.to overriding
This commit is contained in:
committed by
GitHub
parent
439de3f7f9
commit
7b3bd1f21a
@@ -81,4 +81,5 @@ This model was contributed by [qqaatw](https://huggingface.co/qqaatw). The origi
|
|||||||
## RealmForOpenQA
|
## RealmForOpenQA
|
||||||
|
|
||||||
[[autodoc]] RealmForOpenQA
|
[[autodoc]] RealmForOpenQA
|
||||||
|
- block_embedding_to
|
||||||
- forward
|
- forward
|
||||||
@@ -48,6 +48,7 @@ else:
|
|||||||
TOKENIZER_MAPPING_NAMES = OrderedDict(
|
TOKENIZER_MAPPING_NAMES = OrderedDict(
|
||||||
[
|
[
|
||||||
("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
|
("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
|
||||||
|
("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)),
|
||||||
("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
|
("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
|
||||||
("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
|
("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
|
||||||
("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
|
("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
|
||||||
|
|||||||
@@ -836,13 +836,13 @@ class RealmReaderProjection(nn.Module):
|
|||||||
self.layer_normalization = nn.LayerNorm(config.span_hidden_size, eps=config.reader_layer_norm_eps)
|
self.layer_normalization = nn.LayerNorm(config.span_hidden_size, eps=config.reader_layer_norm_eps)
|
||||||
self.relu = nn.ReLU()
|
self.relu = nn.ReLU()
|
||||||
|
|
||||||
def forward(self, hidden_states, token_type_ids):
|
def forward(self, hidden_states, block_mask):
|
||||||
def span_candidates(masks):
|
def span_candidates(masks):
|
||||||
"""
|
"""
|
||||||
Generate span candidates.
|
Generate span candidates.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
masks: <int32> [num_retrievals, max_sequence_len]
|
masks: <bool> [num_retrievals, max_sequence_len]
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
starts: <int32> [num_spans] ends: <int32> [num_spans] span_masks: <int32> [num_retrievals, num_spans]
|
starts: <int32> [num_spans] ends: <int32> [num_spans] span_masks: <int32> [num_retrievals, num_spans]
|
||||||
@@ -875,8 +875,7 @@ class RealmReaderProjection(nn.Module):
|
|||||||
hidden_states = self.dense_intermediate(hidden_states)
|
hidden_states = self.dense_intermediate(hidden_states)
|
||||||
# [reader_beam_size, max_sequence_len, span_hidden_size]
|
# [reader_beam_size, max_sequence_len, span_hidden_size]
|
||||||
start_projection, end_projection = hidden_states.chunk(2, dim=-1)
|
start_projection, end_projection = hidden_states.chunk(2, dim=-1)
|
||||||
block_mask = token_type_ids.detach().clone()
|
|
||||||
block_mask[:, -1] = 0
|
|
||||||
candidate_starts, candidate_ends, candidate_mask = span_candidates(block_mask)
|
candidate_starts, candidate_ends, candidate_mask = span_candidates(block_mask)
|
||||||
|
|
||||||
candidate_start_projections = torch.index_select(start_projection, dim=1, index=candidate_starts)
|
candidate_start_projections = torch.index_select(start_projection, dim=1, index=candidate_starts)
|
||||||
@@ -1543,6 +1542,7 @@ class RealmReader(RealmPreTrainedModel):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
relevance_score=None,
|
relevance_score=None,
|
||||||
|
block_mask=None,
|
||||||
start_positions=None,
|
start_positions=None,
|
||||||
end_positions=None,
|
end_positions=None,
|
||||||
has_answers=None,
|
has_answers=None,
|
||||||
@@ -1552,12 +1552,15 @@ class RealmReader(RealmPreTrainedModel):
|
|||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
relevance_score (`torch.FloatTensor` of shape `(searcher_beam_size,)`, *optional*):
|
relevance_score (`torch.FloatTensor` of shape `(searcher_beam_size,)`, *optional*):
|
||||||
Relevance score, which must be specified if you want to compute the marginal log loss.
|
Relevance score, which must be specified if you want to compute the logits and marginal log loss.
|
||||||
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
block_mask (`torch.BoolTensor` of shape `(searcher_beam_size, sequence_length)`, *optional*):
|
||||||
|
The mask of the evidence block, which must be specified if you want to compute the logits and marginal log
|
||||||
|
loss.
|
||||||
|
start_positions (`torch.LongTensor` of shape `(searcher_beam_size,)`, *optional*):
|
||||||
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||||
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
||||||
are not taken into account for computing the loss.
|
are not taken into account for computing the loss.
|
||||||
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
end_positions (`torch.LongTensor` of shape `(searcher_beam_size,)`, *optional*):
|
||||||
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
||||||
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
||||||
are not taken into account for computing the loss.
|
are not taken into account for computing the loss.
|
||||||
@@ -1570,8 +1573,8 @@ class RealmReader(RealmPreTrainedModel):
|
|||||||
|
|
||||||
if relevance_score is None:
|
if relevance_score is None:
|
||||||
raise ValueError("You have to specify `relevance_score` to calculate logits and loss.")
|
raise ValueError("You have to specify `relevance_score` to calculate logits and loss.")
|
||||||
if token_type_ids is None:
|
if block_mask is None:
|
||||||
raise ValueError("You have to specify `token_type_ids` to separate question block and evidence block.")
|
raise ValueError("You have to specify `block_mask` to separate question block and evidence block.")
|
||||||
if token_type_ids.size(1) < self.config.max_span_width:
|
if token_type_ids.size(1) < self.config.max_span_width:
|
||||||
raise ValueError("The input sequence length must be greater than or equal to config.max_span_width.")
|
raise ValueError("The input sequence length must be greater than or equal to config.max_span_width.")
|
||||||
outputs = self.realm(
|
outputs = self.realm(
|
||||||
@@ -1590,7 +1593,9 @@ class RealmReader(RealmPreTrainedModel):
|
|||||||
sequence_output = outputs[0]
|
sequence_output = outputs[0]
|
||||||
|
|
||||||
# [reader_beam_size, num_candidates], [num_candidates], [num_candidates]
|
# [reader_beam_size, num_candidates], [num_candidates], [num_candidates]
|
||||||
reader_logits, candidate_starts, candidate_ends = self.qa_outputs(sequence_output, token_type_ids)
|
reader_logits, candidate_starts, candidate_ends = self.qa_outputs(
|
||||||
|
sequence_output, block_mask[0 : self.config.reader_beam_size]
|
||||||
|
)
|
||||||
# [searcher_beam_size, 1]
|
# [searcher_beam_size, 1]
|
||||||
retriever_logits = torch.unsqueeze(relevance_score[0 : self.config.reader_beam_size], -1)
|
retriever_logits = torch.unsqueeze(relevance_score[0 : self.config.reader_beam_size], -1)
|
||||||
# [reader_beam_size, num_candidates]
|
# [reader_beam_size, num_candidates]
|
||||||
@@ -1737,11 +1742,21 @@ class RealmForOpenQA(RealmPreTrainedModel):
|
|||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def beam_size(self):
|
def searcher_beam_size(self):
|
||||||
if self.training:
|
if self.training:
|
||||||
return self.config.searcher_beam_size
|
return self.config.searcher_beam_size
|
||||||
return self.config.reader_beam_size
|
return self.config.reader_beam_size
|
||||||
|
|
||||||
|
def block_embedding_to(self, device):
|
||||||
|
"""Send `self.block_emb` to a specific device.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device (`str` or `torch.device`):
|
||||||
|
The device to which `self.block_emb` will be sent.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.block_emb = self.block_emb.to(device)
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(REALM_FOR_OPEN_QA_DOCSTRING.format("1, sequence_length"))
|
@add_start_docstrings_to_model_forward(REALM_FOR_OPEN_QA_DOCSTRING.format("1, sequence_length"))
|
||||||
@replace_return_docstrings(output_type=RealmForOpenQAOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=RealmForOpenQAOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
@@ -1787,36 +1802,37 @@ class RealmForOpenQA(RealmPreTrainedModel):
|
|||||||
question_outputs = self.embedder(
|
question_outputs = self.embedder(
|
||||||
input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, return_dict=True
|
input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, return_dict=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# [1, projection_size]
|
# [1, projection_size]
|
||||||
question_projection = question_outputs[0]
|
question_projection = question_outputs[0]
|
||||||
|
|
||||||
|
# CPU computation starts.
|
||||||
# [1, block_emb_size]
|
# [1, block_emb_size]
|
||||||
batch_scores = torch.einsum("BD,QD->QB", self.block_emb, question_projection)
|
batch_scores = torch.einsum("BD,QD->QB", self.block_emb, question_projection.to(self.block_emb.device))
|
||||||
# [1, searcher_beam_size]
|
# [1, searcher_beam_size]
|
||||||
_, retrieved_block_ids = torch.topk(batch_scores, k=self.beam_size, dim=-1)
|
_, retrieved_block_ids = torch.topk(batch_scores, k=self.searcher_beam_size, dim=-1)
|
||||||
# [searcher_beam_size]
|
# [searcher_beam_size]
|
||||||
# Must convert to cpu tensor for subsequent numpy operations
|
retrieved_block_ids = retrieved_block_ids.squeeze()
|
||||||
retrieved_block_ids = retrieved_block_ids.squeeze().cpu()
|
# [searcher_beam_size, projection_size]
|
||||||
|
retrieved_block_emb = torch.index_select(self.block_emb, dim=0, index=retrieved_block_ids)
|
||||||
|
# CPU computation ends.
|
||||||
|
|
||||||
# Retrieve possible answers
|
# Retrieve possible answers
|
||||||
has_answers, start_pos, end_pos, concat_inputs = self.retriever(
|
has_answers, start_pos, end_pos, concat_inputs = self.retriever(
|
||||||
retrieved_block_ids, input_ids, answer_ids, max_length=self.config.reader_seq_len
|
retrieved_block_ids.cpu(), input_ids, answer_ids, max_length=self.config.reader_seq_len
|
||||||
)
|
)
|
||||||
|
|
||||||
|
concat_inputs = concat_inputs.to(self.reader.device)
|
||||||
|
block_mask = concat_inputs.special_tokens_mask.type(torch.bool).to(device=self.reader.device)
|
||||||
|
block_mask.logical_not_().logical_and_(concat_inputs.token_type_ids.type(torch.bool))
|
||||||
|
|
||||||
if has_answers is not None:
|
if has_answers is not None:
|
||||||
has_answers = torch.tensor(has_answers, dtype=torch.bool, device=self.reader.device)
|
has_answers = torch.tensor(has_answers, dtype=torch.bool, device=self.reader.device)
|
||||||
start_pos = torch.tensor(start_pos, dtype=torch.long, device=self.reader.device)
|
start_pos = torch.tensor(start_pos, dtype=torch.long, device=self.reader.device)
|
||||||
end_pos = torch.tensor(end_pos, dtype=torch.long, device=self.reader.device)
|
end_pos = torch.tensor(end_pos, dtype=torch.long, device=self.reader.device)
|
||||||
|
|
||||||
concat_inputs = concat_inputs.to(self.reader.device)
|
|
||||||
|
|
||||||
# [searcher_beam_size, projection_size]
|
|
||||||
retrieved_block_emb = torch.index_select(
|
|
||||||
self.block_emb, dim=0, index=retrieved_block_ids.to(self.block_emb.device)
|
|
||||||
)
|
|
||||||
# [searcher_beam_size]
|
# [searcher_beam_size]
|
||||||
retrieved_logits = torch.einsum(
|
retrieved_logits = torch.einsum(
|
||||||
"D,BD->B", question_projection.squeeze(), retrieved_block_emb.to(question_projection.device)
|
"D,BD->B", question_projection.squeeze(), retrieved_block_emb.to(self.reader.device)
|
||||||
)
|
)
|
||||||
|
|
||||||
reader_output = self.reader(
|
reader_output = self.reader(
|
||||||
@@ -1824,6 +1840,7 @@ class RealmForOpenQA(RealmPreTrainedModel):
|
|||||||
attention_mask=concat_inputs.attention_mask[0 : self.config.reader_beam_size],
|
attention_mask=concat_inputs.attention_mask[0 : self.config.reader_beam_size],
|
||||||
token_type_ids=concat_inputs.token_type_ids[0 : self.config.reader_beam_size],
|
token_type_ids=concat_inputs.token_type_ids[0 : self.config.reader_beam_size],
|
||||||
relevance_score=retrieved_logits,
|
relevance_score=retrieved_logits,
|
||||||
|
block_mask=block_mask,
|
||||||
has_answers=has_answers,
|
has_answers=has_answers,
|
||||||
start_positions=start_pos,
|
start_positions=start_pos,
|
||||||
end_positions=end_pos,
|
end_positions=end_pos,
|
||||||
|
|||||||
@@ -20,9 +20,9 @@ from typing import Optional, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from .tokenization_realm import RealmTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
_REALM_BLOCK_RECORDS_FILENAME = "block_records.npy"
|
_REALM_BLOCK_RECORDS_FILENAME = "block_records.npy"
|
||||||
@@ -97,7 +97,9 @@ class RealmRetriever:
|
|||||||
text.append(question)
|
text.append(question)
|
||||||
text_pair.append(retrieved_block.decode())
|
text_pair.append(retrieved_block.decode())
|
||||||
|
|
||||||
concat_inputs = self.tokenizer(text, text_pair, padding=True, truncation=True, max_length=max_length)
|
concat_inputs = self.tokenizer(
|
||||||
|
text, text_pair, padding=True, truncation=True, return_special_tokens_mask=True, max_length=max_length
|
||||||
|
)
|
||||||
concat_inputs_tensors = concat_inputs.convert_to_tensors(return_tensors)
|
concat_inputs_tensors = concat_inputs.convert_to_tensors(return_tensors)
|
||||||
|
|
||||||
if answer_ids is not None:
|
if answer_ids is not None:
|
||||||
@@ -115,7 +117,7 @@ class RealmRetriever:
|
|||||||
)
|
)
|
||||||
block_records = np.load(block_records_path, allow_pickle=True)
|
block_records = np.load(block_records_path, allow_pickle=True)
|
||||||
|
|
||||||
tokenizer = RealmTokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)
|
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)
|
||||||
|
|
||||||
return cls(block_records, tokenizer)
|
return cls(block_records, tokenizer)
|
||||||
|
|
||||||
@@ -133,13 +135,15 @@ class RealmRetriever:
|
|||||||
max_answers = 0
|
max_answers = 0
|
||||||
|
|
||||||
for input_id in concat_inputs.input_ids:
|
for input_id in concat_inputs.input_ids:
|
||||||
|
input_id_list = input_id.tolist()
|
||||||
|
# Check answers between two [SEP] tokens
|
||||||
|
first_sep_idx = input_id_list.index(self.tokenizer.sep_token_id)
|
||||||
|
second_sep_idx = first_sep_idx + 1 + input_id_list[first_sep_idx + 1 :].index(self.tokenizer.sep_token_id)
|
||||||
|
|
||||||
start_pos.append([])
|
start_pos.append([])
|
||||||
end_pos.append([])
|
end_pos.append([])
|
||||||
input_id_list = input_id.tolist()
|
|
||||||
# Checking answers after the [SEP] token
|
|
||||||
sep_idx = input_id_list.index(self.tokenizer.sep_token_id)
|
|
||||||
for answer in answer_ids:
|
for answer in answer_ids:
|
||||||
for idx in range(sep_idx, len(input_id)):
|
for idx in range(first_sep_idx + 1, second_sep_idx):
|
||||||
if answer[0] == input_id_list[idx]:
|
if answer[0] == input_id_list[idx]:
|
||||||
if input_id_list[idx : idx + len(answer)] == answer:
|
if input_id_list[idx : idx + len(answer)] == answer:
|
||||||
start_pos[-1].append(idx)
|
start_pos[-1].append(idx)
|
||||||
@@ -158,5 +162,4 @@ class RealmRetriever:
|
|||||||
padded = [-1] * (max_answers - len(start_pos_))
|
padded = [-1] * (max_answers - len(start_pos_))
|
||||||
start_pos_ += padded
|
start_pos_ += padded
|
||||||
end_pos_ += padded
|
end_pos_ += padded
|
||||||
|
|
||||||
return has_answers, start_pos, end_pos
|
return has_answers, start_pos, end_pos
|
||||||
|
|||||||
@@ -345,7 +345,7 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
self.model_tester.create_and_check_embedder(*config_and_inputs)
|
self.model_tester.create_and_check_embedder(*config_and_inputs)
|
||||||
self.model_tester.create_and_check_encoder(*config_and_inputs)
|
self.model_tester.create_and_check_encoder(*config_and_inputs)
|
||||||
|
|
||||||
def test_retriever(self):
|
def test_scorer(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_scorer(*config_and_inputs)
|
self.model_tester.create_and_check_scorer(*config_and_inputs)
|
||||||
|
|
||||||
@@ -408,6 +408,13 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
loss = model(**inputs).reader_output.loss
|
loss = model(**inputs).reader_output.loss
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
# Test model.block_embedding_to
|
||||||
|
device = torch.device("cpu")
|
||||||
|
model.block_embedding_to(device)
|
||||||
|
loss = model(**inputs).reader_output.loss
|
||||||
|
loss.backward()
|
||||||
|
self.assertEqual(model.block_emb.device.type, device.type)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_embedder_from_pretrained(self):
|
def test_embedder_from_pretrained(self):
|
||||||
model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")
|
model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")
|
||||||
@@ -506,10 +513,15 @@ class RealmModelIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
concat_input_ids = torch.arange(10).view((2, 5))
|
concat_input_ids = torch.arange(10).view((2, 5))
|
||||||
concat_token_type_ids = torch.tensor([[0, 0, 1, 1, 1], [0, 0, 1, 1, 1]], dtype=torch.int64)
|
concat_token_type_ids = torch.tensor([[0, 0, 1, 1, 1], [0, 0, 1, 1, 1]], dtype=torch.int64)
|
||||||
|
concat_block_mask = torch.tensor([[0, 0, 1, 1, 0], [0, 0, 1, 1, 0]], dtype=torch.int64)
|
||||||
relevance_score = torch.tensor([0.3, 0.7], dtype=torch.float32)
|
relevance_score = torch.tensor([0.3, 0.7], dtype=torch.float32)
|
||||||
|
|
||||||
output = model(
|
output = model(
|
||||||
concat_input_ids, token_type_ids=concat_token_type_ids, relevance_score=relevance_score, return_dict=True
|
concat_input_ids,
|
||||||
|
token_type_ids=concat_token_type_ids,
|
||||||
|
relevance_score=relevance_score,
|
||||||
|
block_mask=concat_block_mask,
|
||||||
|
return_dict=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
block_idx_expected_shape = torch.Size(())
|
block_idx_expected_shape = torch.Size(())
|
||||||
|
|||||||
@@ -98,6 +98,7 @@ class RealmRetrieverTest(TestCase):
|
|||||||
b"This is the third record",
|
b"This is the third record",
|
||||||
b"This is the fourth record",
|
b"This is the fourth record",
|
||||||
b"This is the fifth record",
|
b"This is the fifth record",
|
||||||
|
b"This is a longer longer longer record",
|
||||||
],
|
],
|
||||||
dtype=np.object,
|
dtype=np.object,
|
||||||
)
|
)
|
||||||
@@ -135,6 +136,7 @@ class RealmRetrieverTest(TestCase):
|
|||||||
self.assertEqual(concat_inputs.input_ids.shape, (2, 10))
|
self.assertEqual(concat_inputs.input_ids.shape, (2, 10))
|
||||||
self.assertEqual(concat_inputs.attention_mask.shape, (2, 10))
|
self.assertEqual(concat_inputs.attention_mask.shape, (2, 10))
|
||||||
self.assertEqual(concat_inputs.token_type_ids.shape, (2, 10))
|
self.assertEqual(concat_inputs.token_type_ids.shape, (2, 10))
|
||||||
|
self.assertEqual(concat_inputs.special_tokens_mask.shape, (2, 10))
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
tokenizer.convert_ids_to_tokens(concat_inputs.input_ids[0]),
|
tokenizer.convert_ids_to_tokens(concat_inputs.input_ids[0]),
|
||||||
["[CLS]", "test", "question", "[SEP]", "this", "is", "the", "first", "record", "[SEP]"],
|
["[CLS]", "test", "question", "[SEP]", "this", "is", "the", "first", "record", "[SEP]"],
|
||||||
@@ -149,10 +151,10 @@ class RealmRetrieverTest(TestCase):
|
|||||||
retriever = self.get_dummy_retriever()
|
retriever = self.get_dummy_retriever()
|
||||||
tokenizer = retriever.tokenizer
|
tokenizer = retriever.tokenizer
|
||||||
|
|
||||||
retrieved_block_ids = np.array([0, 3], dtype=np.long)
|
retrieved_block_ids = np.array([0, 3, 5], dtype=np.long)
|
||||||
question_input_ids = tokenizer(["Test question"]).input_ids
|
question_input_ids = tokenizer(["Test question"]).input_ids
|
||||||
answer_ids = tokenizer(
|
answer_ids = tokenizer(
|
||||||
["the fourth"],
|
["the fourth", "longer longer"],
|
||||||
add_special_tokens=False,
|
add_special_tokens=False,
|
||||||
return_token_type_ids=False,
|
return_token_type_ids=False,
|
||||||
return_attention_mask=False,
|
return_attention_mask=False,
|
||||||
@@ -163,9 +165,9 @@ class RealmRetrieverTest(TestCase):
|
|||||||
retrieved_block_ids, question_input_ids, answer_ids=answer_ids, max_length=max_length, return_tensors="np"
|
retrieved_block_ids, question_input_ids, answer_ids=answer_ids, max_length=max_length, return_tensors="np"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual([False, True], has_answers)
|
self.assertEqual([False, True, True], has_answers)
|
||||||
self.assertEqual([[-1], [6]], start_pos)
|
self.assertEqual([[-1, -1, -1], [6, -1, -1], [6, 7, 8]], start_pos)
|
||||||
self.assertEqual([[-1], [7]], end_pos)
|
self.assertEqual([[-1, -1, -1], [7, -1, -1], [7, 8, 9]], end_pos)
|
||||||
|
|
||||||
def test_save_load_pretrained(self):
|
def test_save_load_pretrained(self):
|
||||||
retriever = self.get_dummy_retriever()
|
retriever = self.get_dummy_retriever()
|
||||||
|
|||||||
Reference in New Issue
Block a user