From 917dbb15e07398b119d49449032b2224a7817eb7 Mon Sep 17 00:00:00 2001 From: Sergey Mkrtchyan Date: Tue, 19 Jan 2021 05:43:11 -0500 Subject: [PATCH] Fix DPRReaderTokenizer's attention_mask (#9663) * Fix the attention_mask in DPRReaderTokenizer * Add an integration test for DPRReader inference * Run make style --- .../models/dpr/tokenization_dpr.py | 4 ++- .../models/dpr/tokenization_dpr_fast.py | 4 ++- tests/test_modeling_dpr.py | 32 ++++++++++++++++++- 3 files changed, 37 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/dpr/tokenization_dpr.py b/src/transformers/models/dpr/tokenization_dpr.py index 253167791e..73c4010a28 100644 --- a/src/transformers/models/dpr/tokenization_dpr.py +++ b/src/transformers/models/dpr/tokenization_dpr.py @@ -251,7 +251,9 @@ class CustomDPRReaderTokenizerMixin: ] } if return_attention_mask is not False: - attention_mask = [input_ids != self.pad_token_id for input_ids in encoded_inputs["input_ids"]] + attention_mask = [] + for input_ids in encoded_inputs["input_ids"]: + attention_mask.append([int(input_id != self.pad_token_id) for input_id in input_ids]) encoded_inputs["attention_mask"] = attention_mask return self.pad(encoded_inputs, padding=padding, max_length=max_length, return_tensors=return_tensors) diff --git a/src/transformers/models/dpr/tokenization_dpr_fast.py b/src/transformers/models/dpr/tokenization_dpr_fast.py index bd8d1fecd3..f67e75594b 100644 --- a/src/transformers/models/dpr/tokenization_dpr_fast.py +++ b/src/transformers/models/dpr/tokenization_dpr_fast.py @@ -252,7 +252,9 @@ class CustomDPRReaderTokenizerMixin: ] } if return_attention_mask is not False: - attention_mask = [input_ids != self.pad_token_id for input_ids in encoded_inputs["input_ids"]] + attention_mask = [] + for input_ids in encoded_inputs["input_ids"]: + attention_mask.append([int(input_id != self.pad_token_id) for input_id in input_ids]) encoded_inputs["attention_mask"] = attention_mask return self.pad(encoded_inputs, padding=padding, max_length=max_length, return_tensors=return_tensors) diff --git a/tests/test_modeling_dpr.py b/tests/test_modeling_dpr.py index be27fe8807..f31563670a 100644 --- a/tests/test_modeling_dpr.py +++ b/tests/test_modeling_dpr.py @@ -26,7 +26,7 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention if is_torch_available(): import torch - from transformers import DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader + from transformers import DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader, DPRReaderTokenizer from transformers.models.dpr.modeling_dpr import ( DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST, DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST, @@ -260,3 +260,33 @@ class DPRModelIntegrationTest(unittest.TestCase): device=torch_device, ) self.assertTrue(torch.allclose(output[:, :10], expected_slice, atol=1e-4)) + + @slow + def test_reader_inference(self): + tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base") + model = DPRReader.from_pretrained("facebook/dpr-reader-single-nq-base") + + encoded_inputs = tokenizer( + questions="What is love ?", + titles="Haddaway", + texts="What Is Love is a song recorded by the artist Haddaway", + padding=True, + return_tensors="pt", + ) + + outputs = model(**encoded_inputs) + + # compare the actual values for a slice. + expected_start_logits = torch.tensor( + [[-10.3005, -10.7765, -11.4872, -11.6841, -11.9312, -10.3002, -9.8544, -11.7378, -12.0821, -10.2975]], + dtype=torch.float, + device=torch_device, + ) + + expected_end_logits = torch.tensor( + [[-11.0684, -11.7041, -11.5397, -10.3465, -10.8791, -6.8443, -11.9959, -11.0364, -10.0096, -6.8405]], + dtype=torch.float, + device=torch_device, + ) + self.assertTrue(torch.allclose(outputs.start_logits[:, :10], expected_start_logits, atol=1e-4)) + self.assertTrue(torch.allclose(outputs.end_logits[:, :10], expected_end_logits, atol=1e-4))