Fix DPRReaderTokenizer's attention_mask (#9663)
* Fix the attention_mask in DPRReaderTokenizer * Add an integration test for DPRReader inference * Run make style
This commit is contained in:
@@ -251,7 +251,9 @@ class CustomDPRReaderTokenizerMixin:
|
||||
]
|
||||
}
|
||||
if return_attention_mask is not False:
|
||||
attention_mask = [input_ids != self.pad_token_id for input_ids in encoded_inputs["input_ids"]]
|
||||
attention_mask = []
|
||||
for input_ids in encoded_inputs["input_ids"]:
|
||||
attention_mask.append([int(input_id != self.pad_token_id) for input_id in input_ids])
|
||||
encoded_inputs["attention_mask"] = attention_mask
|
||||
return self.pad(encoded_inputs, padding=padding, max_length=max_length, return_tensors=return_tensors)
|
||||
|
||||
|
||||
@@ -252,7 +252,9 @@ class CustomDPRReaderTokenizerMixin:
|
||||
]
|
||||
}
|
||||
if return_attention_mask is not False:
|
||||
attention_mask = [input_ids != self.pad_token_id for input_ids in encoded_inputs["input_ids"]]
|
||||
attention_mask = []
|
||||
for input_ids in encoded_inputs["input_ids"]:
|
||||
attention_mask.append([int(input_id != self.pad_token_id) for input_id in input_ids])
|
||||
encoded_inputs["attention_mask"] = attention_mask
|
||||
return self.pad(encoded_inputs, padding=padding, max_length=max_length, return_tensors=return_tensors)
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader
|
||||
from transformers import DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader, DPRReaderTokenizer
|
||||
from transformers.models.dpr.modeling_dpr import (
|
||||
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
@@ -260,3 +260,33 @@ class DPRModelIntegrationTest(unittest.TestCase):
|
||||
device=torch_device,
|
||||
)
|
||||
self.assertTrue(torch.allclose(output[:, :10], expected_slice, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_reader_inference(self):
|
||||
tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base")
|
||||
model = DPRReader.from_pretrained("facebook/dpr-reader-single-nq-base")
|
||||
|
||||
encoded_inputs = tokenizer(
|
||||
questions="What is love ?",
|
||||
titles="Haddaway",
|
||||
texts="What Is Love is a song recorded by the artist Haddaway",
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
outputs = model(**encoded_inputs)
|
||||
|
||||
# compare the actual values for a slice.
|
||||
expected_start_logits = torch.tensor(
|
||||
[[-10.3005, -10.7765, -11.4872, -11.6841, -11.9312, -10.3002, -9.8544, -11.7378, -12.0821, -10.2975]],
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
expected_end_logits = torch.tensor(
|
||||
[[-11.0684, -11.7041, -11.5397, -10.3465, -10.8791, -6.8443, -11.9959, -11.0364, -10.0096, -6.8405]],
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
self.assertTrue(torch.allclose(outputs.start_logits[:, :10], expected_start_logits, atol=1e-4))
|
||||
self.assertTrue(torch.allclose(outputs.end_logits[:, :10], expected_end_logits, atol=1e-4))
|
||||
|
||||
Reference in New Issue
Block a user