Add DPR model (#5279)
* beginning of dpr modeling
* wip
* implement forward
* remove biencoder + better init weights
* export dpr model to embed model for nlp lib
* add new api
* remove old code
* make style
* fix dumb typo
* don't load bert weights
* docs
* docs
* style
* move the `k` parameter
* fix init_weights
* add pretrained configs
* minor
* update config names
* style
* better config
* style
* clean code based on PR comments
* change Dpr to DPR
* fix config
* switch encoder config to a dict
* style
* inheritance -> composition
* add messages in assert statements
* add dpr reader tokenizer
* one tokenizer per model
* fix base_model_prefix
* fix imports
* typo
* add convert script
* docs
* change tokenizers conf names
* style
* change tokenizers conf names
* minor
* minor
* fix wrong names
* minor
* remove unused convert functions
* rename convert script
* use return_tensors in tokenizers
* remove n_questions dim
* move generate logic to tokenizer
* style
* add docs
* docs
* quality
* docs
* add tests
* style
* add tokenization tests
* DPR full tests
* Stay true to the attention mask building
* update docs
* missing param in bert input docs
* docs
* style

Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
@@ -121,7 +121,10 @@ conversion utilities for the following models:
|
||||
trained using `OPUS <http://opus.nlpl.eu/>`_ pretrained_models data by Jörg Tiedemann.
|
||||
21. `Longformer <https://github.com/allenai/longformer>`_ (from AllenAI) released with the paper `Longformer: The
|
||||
Long-Document Transformer <https://arxiv.org/abs/2004.05150>`_ by Iz Beltagy, Matthew E. Peters, and Arman Cohan.
|
||||
22. `Other community models <https://huggingface.co/models>`_, contributed by the `community
|
||||
22. `DPR <https://github.com/facebookresearch/DPR>`_ (from Facebook) released with the paper `Dense Passage Retrieval
|
||||
for Open-Domain Question Answering <https://arxiv.org/abs/2004.04906>`_ by Vladimir Karpukhin, Barlas Oğuz, Sewon
|
||||
Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
|
||||
23. `Other community models <https://huggingface.co/models>`_, contributed by the `community
|
||||
<https://huggingface.co/users>`_.
|
||||
|
||||
.. toctree::
|
||||
@@ -199,3 +202,4 @@ conversion utilities for the following models:
|
||||
model_doc/longformer
|
||||
model_doc/retribert
|
||||
model_doc/mobilebert
|
||||
model_doc/dpr
|
||||
|
||||
89
docs/source/model_doc/dpr.rst
Normal file
@@ -0,0 +1,89 @@
|
||||
DPR
|
||||
----------------------------------------------------
|
||||
|
||||
Overview
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Dense Passage Retrieval (DPR) is a set of tools and models for state-of-the-art open-domain Q&A research.
|
||||
It is based on the following paper:
|
||||
|
||||
Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. `Dense Passage Retrieval for Open-Domain Question Answering <https://arxiv.org/abs/2004.04906>`_.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Open-domain question answering relies on efficient passage retrieval to select candidate contexts, where traditional
|
||||
sparse vector space models, such as TF-IDF or BM25, are the de facto method. In this work, we show that retrieval can
|
||||
be practically implemented using dense representations alone, where embeddings are learned from a small number of
|
||||
questions and passages by a simple dual-encoder framework. When evaluated on a wide range of open-domain QA datasets,
|
||||
our dense retriever outperforms a strong Lucene-BM25 system largely by 9%-19% absolute in terms of top-20 passage
|
||||
retrieval accuracy, and helps our end-to-end QA system establish new state-of-the-art on multiple open-domain QA
|
||||
benchmarks.*
|
||||
|
||||
The original code can be found `here <https://github.com/facebookresearch/DPR>`_.
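As a rough illustration of the dual-encoder idea described in the abstract, the sketch below ranks passages by the inner product between a question embedding and each passage embedding. The toy 3-dimensional vectors stand in for encoder outputs, and the helper names ``dot`` and ``top_k`` are hypothetical; this is a minimal sketch assuming inner-product similarity, not the library's retrieval code.

```python
from typing import List, Sequence, Tuple


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def top_k(question: Sequence[float], passages: List[Sequence[float]], k: int) -> List[Tuple[int, float]]:
    """Return (passage index, score) pairs for the k highest inner products."""
    scored = [(i, dot(question, p)) for i, p in enumerate(passages)]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy embeddings (assumption: real ones would come from the context and question encoders).
passage_embeddings = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.5, 0.5, 0.0],
]
question_embedding = [0.9, 0.1, 0.0]

ranking = top_k(question_embedding, passage_embeddings, k=2)
best_indices = [index for index, _ in ranking]
```

In practice the passage embeddings would be precomputed once and stored in a nearest-neighbor index, so only the question has to be encoded at query time.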
|
||||
|
||||
|
||||
DPRConfig
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.DPRConfig
|
||||
:members:
|
||||
|
||||
|
||||
DPRContextEncoderTokenizer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.DPRContextEncoderTokenizer
|
||||
:members:
|
||||
|
||||
|
||||
DPRContextEncoderTokenizerFast
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.DPRContextEncoderTokenizerFast
|
||||
:members:
|
||||
|
||||
DPRQuestionEncoderTokenizer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.DPRQuestionEncoderTokenizer
|
||||
:members:
|
||||
|
||||
|
||||
DPRQuestionEncoderTokenizerFast
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.DPRQuestionEncoderTokenizerFast
|
||||
:members:
|
||||
|
||||
DPRReaderTokenizer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.DPRReaderTokenizer
|
||||
:members:
|
||||
|
||||
|
||||
DPRReaderTokenizerFast
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.DPRReaderTokenizerFast
|
||||
:members:
|
||||
|
||||
|
||||
DPRContextEncoder
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.DPRContextEncoder
|
||||
:members:
|
||||
|
||||
DPRQuestionEncoder
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.DPRQuestionEncoder
|
||||
:members:
|
||||
|
||||
|
||||
DPRReader
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.DPRReader
|
||||
:members:
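The reader expects each passage as ``[CLS] <question token ids> [SEP] <titles ids> [SEP] <texts ids>``, padded on the right. The sketch below assembles one such row from already-tokenized ids. The special-token ids and the helper ``build_reader_input`` are hypothetical and chosen for illustration only; in real use the reader tokenizer would produce these sequences.

```python
from typing import List, Tuple

# Hypothetical special-token ids, for illustration only.
CLS_ID = 101
SEP_ID = 102
PAD_ID = 0


def build_reader_input(
    question_ids: List[int], title_ids: List[int], text_ids: List[int], max_length: int
) -> Tuple[List[int], List[int]]:
    """Build one reader row: [CLS] question [SEP] title [SEP] text, right-padded."""
    ids = [CLS_ID] + question_ids + [SEP_ID] + title_ids + [SEP_ID] + text_ids
    ids = ids[:max_length]  # truncate overly long rows
    n_pad = max_length - len(ids)
    attention_mask = [1] * len(ids) + [0] * n_pad  # 1 = real token, 0 = padding
    return ids + [PAD_ID] * n_pad, attention_mask


input_ids, attention_mask = build_reader_input([7, 8], [9], [10, 11], max_length=10)
```

Stacking one such row per candidate passage gives a batch of shape ``(n_passages, sequence_length)``, which matches the shape the reader documents for ``input_ids``.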
|
||||
@@ -27,6 +27,7 @@ from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
|
||||
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
|
||||
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
|
||||
from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
|
||||
from .configuration_dpr import DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DPRConfig
|
||||
from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
|
||||
from .configuration_encoder_decoder import EncoderDecoderConfig
|
||||
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
|
||||
@@ -129,6 +130,14 @@ from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenize
|
||||
from .tokenization_camembert import CamembertTokenizer
|
||||
from .tokenization_ctrl import CTRLTokenizer
|
||||
from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast
|
||||
from .tokenization_dpr import (
|
||||
DPRContextEncoderTokenizer,
|
||||
DPRContextEncoderTokenizerFast,
|
||||
DPRQuestionEncoderTokenizer,
|
||||
DPRQuestionEncoderTokenizerFast,
|
||||
DPRReaderTokenizer,
|
||||
DPRReaderTokenizerFast,
|
||||
)
|
||||
from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
|
||||
from .tokenization_flaubert import FlaubertTokenizer
|
||||
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
|
||||
@@ -382,6 +391,14 @@ if is_torch_available():
|
||||
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
)
|
||||
|
||||
from .modeling_dpr import (
|
||||
DPRPretrainedContextEncoder,
|
||||
DPRPretrainedQuestionEncoder,
|
||||
DPRPretrainedReader,
|
||||
DPRContextEncoder,
|
||||
DPRQuestionEncoder,
|
||||
DPRReader,
|
||||
)
|
||||
from .modeling_retribert import (
|
||||
RetriBertPreTrainedModel,
|
||||
RetriBertModel,
|
||||
|
||||
49
src/transformers/configuration_dpr.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2010, DPR authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" DPR model configuration """
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
from .configuration_bert import BertConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DPR_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"facebook/dpr-ctx_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-ctx_encoder-single-nq-base/config.json",
|
||||
"facebook/dpr-question_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-question_encoder-single-nq-base/config.json",
|
||||
"facebook/dpr-reader-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-reader-single-nq-base/config.json",
|
||||
}
|
||||
|
||||
|
||||
class DPRConfig(BertConfig):
|
||||
r"""
|
||||
:class:`~transformers.DPRConfig` is the configuration class to store the configuration of a
`DPRContextEncoder`, a `DPRQuestionEncoder`, or a `DPRReader`.
It is used to instantiate the components of the DPR model.
|
||||
|
||||
Args:
|
||||
projection_dim (:obj:`int`, optional, defaults to 0):
|
||||
Dimension of the projection for the context and question encoders.
|
||||
If it is set to zero (default), then no projection is done.
|
||||
"""
|
||||
model_type = "dpr"
|
||||
|
||||
def __init__(self, projection_dim: int = 0, **kwargs): # projection of the encoders, 0 for no projection
|
||||
super().__init__(**kwargs)
|
||||
self.projection_dim = projection_dim
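To make the role of ``projection_dim`` concrete, the sketch below mirrors the rule used by the encoder's ``embeddings_size`` property in ``modeling_dpr.py``: a positive ``projection_dim`` sets the embedding size, and zero keeps the underlying hidden size. The standalone function ``resolve_embeddings_size`` is a hypothetical helper written for this illustration, not part of the library.

```python
def resolve_embeddings_size(hidden_size: int, projection_dim: int = 0) -> int:
    """Return the size of the vectors an encoder would output.

    Assumption: mirrors the rule in the diff, where projection_dim == 0
    means no projection layer is added and the hidden size is used as-is.
    """
    if hidden_size <= 0:
        raise ValueError("Encoder hidden_size can't be zero")
    return projection_dim if projection_dim > 0 else hidden_size


default_size = resolve_embeddings_size(768)        # no projection
projected_size = resolve_embeddings_size(768, 128)  # projected to 128 dimensions
```

With the default of zero, the context and question embeddings keep the dimensionality of the BERT encoder they wrap.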
|
||||
120
src/transformers/convert_dpr_original_checkpoint_to_pytorch.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import argparse
|
||||
import collections
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.serialization import default_restore_location
|
||||
|
||||
from transformers import BertConfig, DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader
|
||||
|
||||
|
||||
CheckpointState = collections.namedtuple(
|
||||
"CheckpointState", ["model_dict", "optimizer_dict", "scheduler_dict", "offset", "epoch", "encoder_params"]
|
||||
)
|
||||
|
||||
|
||||
def load_states_from_checkpoint(model_file: str) -> CheckpointState:
|
||||
print("Reading saved model from {}".format(model_file))
|
||||
state_dict = torch.load(model_file, map_location=lambda s, l: default_restore_location(s, "cpu"))
|
||||
return CheckpointState(**state_dict)
|
||||
|
||||
|
||||
class DPRState:
|
||||
def __init__(self, src_file: Path):
|
||||
self.src_file = src_file
|
||||
|
||||
def load_dpr_model(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def from_type(comp_type: str, *args, **kwargs) -> "DPRState":
|
||||
if comp_type.startswith("c"):
|
||||
return DPRContextEncoderState(*args, **kwargs)
|
||||
if comp_type.startswith("q"):
|
||||
return DPRQuestionEncoderState(*args, **kwargs)
|
||||
if comp_type.startswith("r"):
|
||||
return DPRReaderState(*args, **kwargs)
|
||||
else:
|
||||
raise ValueError("Component type must be either 'ctx_encoder', 'question_encoder' or 'reader'.")
|
||||
|
||||
|
||||
class DPRContextEncoderState(DPRState):
|
||||
def load_dpr_model(self):
|
||||
model = DPRContextEncoder(DPRConfig(**BertConfig.get_config_dict("bert-base-uncased")[0]))
|
||||
print("Loading DPR biencoder from {}".format(self.src_file))
|
||||
saved_state = load_states_from_checkpoint(self.src_file)
|
||||
encoder, prefix = model.ctx_encoder, "ctx_model."
|
||||
state_dict = {}
|
||||
for key, value in saved_state.model_dict.items():
|
||||
if key.startswith(prefix):
|
||||
key = key[len(prefix) :]
|
||||
if not key.startswith("encode_proj."):
|
||||
key = "bert_model." + key
|
||||
state_dict[key] = value
|
||||
encoder.load_state_dict(state_dict)
|
||||
return model
|
||||
|
||||
|
||||
class DPRQuestionEncoderState(DPRState):
|
||||
def load_dpr_model(self):
|
||||
model = DPRQuestionEncoder(DPRConfig(**BertConfig.get_config_dict("bert-base-uncased")[0]))
|
||||
print("Loading DPR biencoder from {}".format(self.src_file))
|
||||
saved_state = load_states_from_checkpoint(self.src_file)
|
||||
encoder, prefix = model.question_encoder, "question_model."
|
||||
state_dict = {}
|
||||
for key, value in saved_state.model_dict.items():
|
||||
if key.startswith(prefix):
|
||||
key = key[len(prefix) :]
|
||||
if not key.startswith("encode_proj."):
|
||||
key = "bert_model." + key
|
||||
state_dict[key] = value
|
||||
encoder.load_state_dict(state_dict)
|
||||
return model
|
||||
|
||||
|
||||
class DPRReaderState(DPRState):
|
||||
def load_dpr_model(self):
|
||||
model = DPRReader(DPRConfig(**BertConfig.get_config_dict("bert-base-uncased")[0]))
|
||||
print("Loading DPR reader from {}".format(self.src_file))
|
||||
saved_state = load_states_from_checkpoint(self.src_file)
|
||||
state_dict = {}
|
||||
for key, value in saved_state.model_dict.items():
|
||||
if key.startswith("encoder.") and not key.startswith("encoder.encode_proj"):
|
||||
key = "encoder.bert_model." + key[len("encoder.") :]
|
||||
state_dict[key] = value
|
||||
model.span_predictor.load_state_dict(state_dict)
|
||||
return model
|
||||
|
||||
|
||||
def convert(comp_type: str, src_file: Path, dest_dir: Path):
|
||||
dest_dir = Path(dest_dir)
|
||||
dest_dir.mkdir(exist_ok=True)
|
||||
|
||||
dpr_state = DPRState.from_type(comp_type, src_file=src_file)
|
||||
model = dpr_state.load_dpr_model()
|
||||
model.save_pretrained(dest_dir)
|
||||
model.from_pretrained(dest_dir) # sanity check
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--type", type=str, help="Type of the component to convert: 'ctx_encoder', 'question_encoder' or 'reader'."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--src",
|
||||
type=str,
|
||||
help="Path to the DPR checkpoint file. Checkpoints can be downloaded from the official DPR repo https://github.com/facebookresearch/DPR. Note that in the official repo, both encoders are stored in the 'retriever' checkpoints.",
|
||||
)
|
||||
parser.add_argument("--dest", type=str, default=None, help="Path to the output PyTorch model directory.")
|
||||
args = parser.parse_args()
|
||||
|
||||
src_file = Path(args.src)
|
||||
dest_dir = f"converted-{src_file.name}" if args.dest is None else args.dest
|
||||
dest_dir = Path(dest_dir)
|
||||
assert src_file.exists()
|
||||
assert (
|
||||
args.type is not None
|
||||
), "Please specify the component type of the DPR model to convert: 'ctx_encoder', 'question_encoder' or 'reader'."
|
||||
convert(args.type, src_file, dest_dir)
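The key renaming performed by the encoder conversion classes can be sketched on a plain dictionary. This assumes, since indentation is not visible in this view, that only keys starting with the component prefix are kept. The helper ``remap_encoder_keys`` and the toy checkpoint are hypothetical; strings stand in for weight tensors.

```python
from typing import Any, Dict


def remap_encoder_keys(model_dict: Dict[str, Any], prefix: str) -> Dict[str, Any]:
    """Keep keys under `prefix`, strip it, and nest BERT weights under 'bert_model.'."""
    remapped = {}
    for key, value in model_dict.items():
        if not key.startswith(prefix):
            continue  # belongs to another component of the biencoder checkpoint
        key = key[len(prefix):]
        if not key.startswith("encode_proj."):
            key = "bert_model." + key  # projection weights stay at the top level
        remapped[key] = value
    return remapped


# Toy checkpoint holding both encoders, as the 'retriever' checkpoints do.
checkpoint = {
    "ctx_model.embeddings.word_embeddings.weight": "W_ctx",
    "ctx_model.encode_proj.weight": "P_ctx",
    "question_model.embeddings.word_embeddings.weight": "W_q",
}
ctx_state = remap_encoder_keys(checkpoint, "ctx_model.")
question_state = remap_encoder_keys(checkpoint, "question_model.")
```

Because both encoders live in one checkpoint, the same function can be called once per prefix to split it into two state dicts.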
|
||||
@@ -617,6 +617,8 @@ BERT_INPUTS_DOCSTRING = r"""
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||
If set to ``True``, the hidden states tensors of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
541
src/transformers/modeling_dpr.py
Normal file
@@ -0,0 +1,541 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 DPR Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch DPR model for Open Domain Question Answering."""
|
||||
|
||||
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .configuration_dpr import DPRConfig
|
||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||
from .modeling_bert import BertModel
|
||||
from .modeling_utils import PreTrainedModel
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"facebook/dpr-ctx_encoder-single-nq-base",
|
||||
]
|
||||
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"facebook/dpr-question_encoder-single-nq-base",
|
||||
]
|
||||
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"facebook/dpr-reader-single-nq-base",
|
||||
]
|
||||
|
||||
|
||||
class DPREncoder(PreTrainedModel):
|
||||
|
||||
base_model_prefix = "bert_model"
|
||||
|
||||
def __init__(self, config: DPRConfig):
|
||||
super().__init__(config)
|
||||
self.bert_model = BertModel(config)
|
||||
assert self.bert_model.config.hidden_size > 0, "Encoder hidden_size can't be zero"
|
||||
self.projection_dim = config.projection_dim
|
||||
if self.projection_dim > 0:
|
||||
self.encode_proj = nn.Linear(self.bert_model.config.hidden_size, config.projection_dim)
|
||||
self.init_weights()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
token_type_ids: Optional[Tensor] = None,
|
||||
inputs_embeds: Optional[Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
) -> Tuple[Tensor, ...]:
|
||||
outputs = self.bert_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_hidden_states=True,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
sequence_output, pooled_output, hidden_states = outputs[:3]
|
||||
pooled_output = sequence_output[:, 0, :]
|
||||
if self.projection_dim > 0:
|
||||
pooled_output = self.encode_proj(pooled_output)
|
||||
|
||||
dpr_encoder_outputs = (sequence_output, pooled_output)
|
||||
|
||||
if output_hidden_states:
|
||||
dpr_encoder_outputs += (hidden_states,)
|
||||
if output_attentions:
|
||||
dpr_encoder_outputs += (outputs[-1],)
|
||||
|
||||
return dpr_encoder_outputs
|
||||
|
||||
@property
|
||||
def embeddings_size(self) -> int:
|
||||
if self.projection_dim > 0:
|
||||
return self.encode_proj.out_features
|
||||
return self.bert_model.config.hidden_size
|
||||
|
||||
def init_weights(self):
|
||||
self.bert_model.init_weights()
|
||||
if self.projection_dim > 0:
|
||||
self.encode_proj.apply(self.bert_model._init_weights)
|
||||
|
||||
|
||||
class DPRSpanPredictor(PreTrainedModel):
|
||||
|
||||
base_model_prefix = "encoder"
|
||||
|
||||
def __init__(self, config: DPRConfig):
|
||||
super().__init__(config)
|
||||
self.encoder = DPREncoder(config)
|
||||
self.qa_outputs = nn.Linear(self.encoder.embeddings_size, 2)
|
||||
self.qa_classifier = nn.Linear(self.encoder.embeddings_size, 1)
|
||||
self.init_weights()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
attention_mask: Tensor,
|
||||
inputs_embeds: Optional[Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
):
|
||||
# notations: n_passages - number of passages in the batch, sequence_length - number of tokens per passage
|
||||
n_passages, sequence_length = input_ids.size() if input_ids is not None else inputs_embeds.size()[:2]
|
||||
# feed encoder
|
||||
outputs = self.encoder(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
|
||||
# compute logits
|
||||
logits = self.qa_outputs(sequence_output)
|
||||
start_logits, end_logits = logits.split(1, dim=-1)
|
||||
start_logits = start_logits.squeeze(-1)
|
||||
end_logits = end_logits.squeeze(-1)
|
||||
relevance_logits = self.qa_classifier(sequence_output[:, 0, :])
|
||||
# resize and return
|
||||
return (
|
||||
start_logits.view(n_passages, sequence_length),
|
||||
end_logits.view(n_passages, sequence_length),
|
||||
relevance_logits.view(n_passages),
|
||||
) + outputs[2:]
|
||||
|
||||
def init_weights(self):
|
||||
self.encoder.init_weights()
|
||||
|
||||
|
||||
##################
|
||||
# PreTrainedModel
|
||||
##################
|
||||
|
||||
|
||||
class DPRPretrainedContextEncoder(PreTrainedModel):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for downloading and loading pretrained models.
|
||||
"""
|
||||
|
||||
config_class = DPRConfig
|
||||
load_tf_weights = None
|
||||
base_model_prefix = "ctx_encoder"
|
||||
|
||||
def init_weights(self):
|
||||
self.ctx_encoder.init_weights()
|
||||
|
||||
|
||||
class DPRPretrainedQuestionEncoder(PreTrainedModel):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for downloading and loading pretrained models.
|
||||
"""
|
||||
|
||||
config_class = DPRConfig
|
||||
load_tf_weights = None
|
||||
base_model_prefix = "question_encoder"
|
||||
|
||||
def init_weights(self):
|
||||
self.question_encoder.init_weights()
|
||||
|
||||
|
||||
class DPRPretrainedReader(PreTrainedModel):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for downloading and loading pretrained models.
|
||||
"""
|
||||
|
||||
config_class = DPRConfig
|
||||
load_tf_weights = None
|
||||
base_model_prefix = "span_predictor"
|
||||
|
||||
def init_weights(self):
|
||||
self.span_predictor.encoder.init_weights()
|
||||
self.span_predictor.qa_classifier.apply(self.span_predictor.encoder.bert_model._init_weights)
|
||||
self.span_predictor.qa_outputs.apply(self.span_predictor.encoder.bert_model._init_weights)
|
||||
|
||||
|
||||
###############
|
||||
# Actual Models
|
||||
###############
|
||||
|
||||
|
||||
DPR_START_DOCSTRING = r"""
|
||||
|
||||
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
|
||||
usage and behavior.
|
||||
|
||||
Parameters:
|
||||
config (:class:`~transformers.DPRConfig`): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
||||
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
||||
"""
|
||||
|
||||
DPR_ENCODERS_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
To match pre-training, DPR input sequence should be formatted with [CLS] and [SEP] tokens as follows:
|
||||
|
||||
(a) For sequence pairs (for a pair title+text for example):
|
||||
|
||||
``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``
|
||||
|
||||
``token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1``
|
||||
|
||||
(b) For single sequences (for a question for example):
|
||||
|
||||
``tokens: [CLS] the dog is hairy . [SEP]``
|
||||
|
||||
``token_type_ids: 0 0 0 0 0 0 0``
|
||||
|
||||
DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on
|
||||
the right rather than the left.
|
||||
|
||||
Indices can be obtained using :class:`transformers.DPRContextEncoderTokenizer` or :class:`transformers.DPRQuestionEncoderTokenizer`.
|
||||
See :func:`transformers.PreTrainedTokenizer.encode` and
|
||||
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Segment token indices to indicate first and second portions of the inputs.
|
||||
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
|
||||
corresponds to a `sentence B` token.
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors.
|
||||
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||
If set to ``True``, the hidden states tensors of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
|
||||
"""
|
||||
|
||||
DPR_READER_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(n_passages, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
It has to be a sequence triplet with 1) the question, 2) the passage titles, and 3) the passage texts.
To match pre-training, the DPR `input_ids` sequence should be formatted with [CLS] and [SEP] tokens as follows:
|
||||
|
||||
[CLS] <question token ids> [SEP] <titles ids> [SEP] <texts ids>
|
||||
|
||||
DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on
|
||||
the right rather than the left.
|
||||
|
||||
Indices can be obtained using :class:`transformers.DPRReaderTokenizer`.
|
||||
See :class:`transformers.DPRReaderTokenizer` for more details
|
||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(n_passages, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(n_passages, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors.
|
||||
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||
If set to ``True``, the hidden states tensors of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare DPRContextEncoder transformer outputting pooler outputs as context representations.",
|
||||
DPR_START_DOCSTRING,
|
||||
)
|
||||
class DPRContextEncoder(DPRPretrainedContextEncoder):
|
||||
def __init__(self, config: DPRConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.ctx_encoder = DPREncoder(config)
|
||||
self.init_weights()
|
||||
|
||||
@add_start_docstrings_to_callable(DPR_ENCODERS_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[Tensor] = None,
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
token_type_ids: Optional[Tensor] = None,
|
||||
inputs_embeds: Optional[Tensor] = None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
) -> Tuple[Tensor, ...]:
|
||||
r"""
|
||||
Return:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.DPRConfig`) and inputs:
|
||||
pooler_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, embeddings_size)`):
|
||||
The DPR encoder outputs the `pooler_output` that corresponds to the context representation.
|
||||
Last layer hidden-state of the first token of the sequence (classification token),
further processed by a linear layer when ``config.projection_dim > 0``. This output is to be used to embed contexts for
nearest neighbors queries with question embeddings.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
|
||||
Examples::
|
||||
|
||||
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
|
||||
tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
|
||||
model = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
|
||||
input_ids = tokenizer("Hello, is my dog cute ?", return_tensors='pt')["input_ids"]
|
||||
embeddings = model(input_ids)[0] # the embeddings of the given context.
|
||||
|
||||
"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = (
|
||||
torch.ones(input_shape, device=device)
|
||||
if input_ids is None
|
||||
else (input_ids != self.config.pad_token_id)
|
||||
)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
outputs = self.ctx_encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
sequence_output, pooled_output = outputs[:2]
|
||||
return (pooled_output,) + outputs[2:]
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare DPRQuestionEncoder transformer outputting pooler outputs as question representations.",
|
||||
DPR_START_DOCSTRING,
|
||||
)
|
||||
class DPRQuestionEncoder(DPRPretrainedQuestionEncoder):
|
||||
def __init__(self, config: DPRConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.question_encoder = DPREncoder(config)
|
||||
self.init_weights()
|
||||
|
||||
@add_start_docstrings_to_callable(DPR_ENCODERS_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[Tensor] = None,
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
token_type_ids: Optional[Tensor] = None,
|
||||
inputs_embeds: Optional[Tensor] = None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
) -> Tensor:
|
||||
r"""
|
||||
Return:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.DPRConfig`) and inputs:
|
||||
pooler_output: (:obj:``torch.FloatTensor`` of shape ``(batch_size, embeddings_size)``):
|
||||
The DPR encoder outputs the `pooler_output` that corresponds to the question representation.
|
||||
Last layer hidden-state of the first token of the sequence (classification token)
|
||||
further processed by a Linear layer. This output is to be used to embed questions for
|
||||
nearest neighbors queries with context embeddings.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
|
||||
Examples::
|
||||
|
||||
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
|
||||
tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
|
||||
model = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
|
||||
input_ids = tokenizer("Hello, is my dog cute ?", return_tensors='pt')["input_ids"]
|
||||
embeddings = model(input_ids)[0] # the embeddings of the given question.
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = (
|
||||
torch.ones(input_shape, device=device)
|
||||
if input_ids is None
|
||||
else (input_ids != self.config.pad_token_id)
|
||||
)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
outputs = self.question_encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
sequence_output, pooled_output = outputs[:2]
|
||||
return (pooled_output,) + outputs[2:]
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare DPRReader transformer outputting span predictions.", DPR_START_DOCSTRING,
|
||||
)
|
||||
class DPRReader(DPRPretrainedReader):
|
||||
def __init__(self, config: DPRConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.span_predictor = DPRSpanPredictor(config)
|
||||
self.init_weights()
|
||||
|
||||
@add_start_docstrings_to_callable(DPR_READER_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[Tensor] = None,
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
inputs_embeds: Optional[Tensor] = None,
|
||||
output_attentions: bool = None,
|
||||
output_hidden_states: bool = None,
|
||||
) -> Tuple[Tensor, ...]:
|
||||
r"""
|
||||
Return:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.DPRConfig`) and inputs:
|
||||
input_ids: (:obj:``torch.FloatTensor`` of shape ``(n_passages, sequence_length)``)
|
||||
They correspond to the combined `input_ids` from `(question + context title + context content)`.
|
||||
start_logits: (:obj:``torch.FloatTensor`` of shape ``(n_passages, sequence_length)``):
|
||||
Logits of the start index of the span for each passage.
|
||||
end_logits: (:obj:``torch.FloatTensor`` of shape ``(n_passages, sequence_length)``):
|
||||
Logits of the end index of the span for each passage.
|
||||
relevance_logits: (:obj:``torch.FloatTensor`` of shape ``(n_passages, )``):
|
||||
Outputs of the QA classifier of the DPRReader that corresponds to the scores of each passage
|
||||
to answer the question, compared to all the other passages.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
|
||||
Examples::
|
||||
|
||||
from transformers import DPRReader, DPRReaderTokenizer
|
||||
tokenizer = DPRReaderTokenizer.from_pretrained('facebook/dpr-reader-single-nq-base')
|
||||
model = DPRReader.from_pretrained('facebook/dpr-reader-single-nq-base')
|
||||
encoded_inputs = tokenizer(
|
||||
questions=["What is love ?"],
|
||||
titles=["Haddaway"],
|
||||
texts=["'What Is Love' is a song recorded by the artist Haddaway"],
|
||||
return_tensors='pt'
|
||||
)
|
||||
outputs = model(**encoded_inputs)
|
||||
start_logits = outputs[0] # The logits of the start of the spans
|
||||
end_logits = outputs[1] # The logits of the end of the spans
|
||||
relevance_logits = outputs[2] # The relevance scores of the passages
|
||||
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(input_shape, device=device)
|
||||
|
||||
span_outputs = self.span_predictor(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
start_logits, end_logits, relevance_logits = span_outputs[:3]
|
||||
|
||||
return (start_logits, end_logits, relevance_logits) + span_outputs[3:]
|
||||
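The encoder docstrings above say the pooled outputs are meant for nearest-neighbor queries between question and context embeddings. A minimal sketch of that lookup follows, assuming inner-product similarity and using plain Python lists in place of real `pooler_output` tensors. The helper name `rank_passages` and the toy vectors are hypothetical and are not part of this commit.

```python
from typing import List, Sequence, Tuple


def rank_passages(
    question_embedding: Sequence[float],
    passage_embeddings: Sequence[Sequence[float]],
    top_k: int = 2,
) -> List[Tuple[int, float]]:
    """Return (passage_index, score) pairs sorted by descending inner product."""
    scores = [
        sum(q * p for q, p in zip(question_embedding, passage))
        for passage in passage_embeddings
    ]
    ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:top_k]


# Toy 3-dimensional vectors standing in for real pooler outputs (assumption).
question = [1.0, 0.0, 2.0]
passages = [
    [0.5, 0.5, 0.5],
    [1.0, 0.0, 1.0],
    [0.0, 1.0, 0.0],
]
top = rank_passages(question, passages, top_k=2)
print(top)
```

In practice the embeddings would come from `DPRQuestionEncoder` and `DPRContextEncoder`, and a dedicated index would likely replace the exhaustive loop for large passage collections.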
384
src/transformers/tokenization_dpr.py
Normal file
@@ -0,0 +1,384 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tokenization classes for DPR."""
|
||||
|
||||
|
||||
import collections
|
||||
import logging
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from .file_utils import add_end_docstrings, add_start_docstrings
|
||||
from .tokenization_bert import BertTokenizer, BertTokenizerFast
|
||||
from .tokenization_utils_base import BatchEncoding, TensorType
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
|
||||
|
||||
CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
|
||||
"vocab_file": {
|
||||
"facebook/dpr-ctx_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
|
||||
}
|
||||
}
|
||||
QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
|
||||
"vocab_file": {
|
||||
"facebook/dpr-question_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
|
||||
}
|
||||
}
|
||||
READER_PRETRAINED_VOCAB_FILES_MAP = {
|
||||
"vocab_file": {
|
||||
"facebook/dpr-reader-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
|
||||
}
|
||||
}
|
||||
|
||||
CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
"facebook/dpr-ctx_encoder-single-nq-base": 512,
|
||||
}
|
||||
QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
"facebook/dpr-question_encoder-single-nq-base": 512,
|
||||
}
|
||||
READER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
"facebook/dpr-reader-single-nq-base": 512,
|
||||
}
|
||||
|
||||
|
||||
CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
|
||||
"facebook/dpr-ctx_encoder-single-nq-base": {"do_lower_case": True},
|
||||
}
|
||||
QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
|
||||
"facebook/dpr-question_encoder-single-nq-base": {"do_lower_case": True},
|
||||
}
|
||||
READER_PRETRAINED_INIT_CONFIGURATION = {
|
||||
"facebook/dpr-reader-single-nq-base": {"do_lower_case": True},
|
||||
}
|
||||
|
||||
|
||||
class DPRContextEncoderTokenizer(BertTokenizer):
|
||||
r"""
|
||||
Constructs a DPRContextEncoderTokenizer.
|
||||
|
||||
:class:`~transformers.DPRContextEncoderTokenizer` is identical to :class:`~transformers.BertTokenizer` and runs end-to-end
|
||||
tokenization: punctuation splitting + wordpiece.
|
||||
|
||||
Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning
|
||||
parameters.
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
pretrained_init_configuration = CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION
|
||||
|
||||
|
||||
class DPRContextEncoderTokenizerFast(BertTokenizerFast):
|
||||
r"""
|
||||
Constructs a "Fast" DPRContextEncoderTokenizer (backed by HuggingFace's `tokenizers` library).
|
||||
|
||||
:class:`~transformers.DPRContextEncoderTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs end-to-end
|
||||
tokenization: punctuation splitting + wordpiece.
|
||||
|
||||
Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning
|
||||
parameters.
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
pretrained_init_configuration = CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION
|
||||
|
||||
|
||||
class DPRQuestionEncoderTokenizer(BertTokenizer):
|
||||
r"""
|
||||
Constructs a DPRQuestionEncoderTokenizer.
|
||||
|
||||
:class:`~transformers.DPRQuestionEncoderTokenizer` is identical to :class:`~transformers.BertTokenizer` and runs end-to-end
|
||||
tokenization: punctuation splitting + wordpiece.
|
||||
|
||||
Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning
|
||||
parameters.
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
pretrained_init_configuration = QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION
|
||||
|
||||
|
||||
class DPRQuestionEncoderTokenizerFast(BertTokenizerFast):
|
||||
r"""
|
||||
Constructs a "Fast" DPRQuestionEncoderTokenizer (backed by HuggingFace's `tokenizers` library).
|
||||
|
||||
:class:`~transformers.DPRQuestionEncoderTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs end-to-end
|
||||
tokenization: punctuation splitting + wordpiece.
|
||||
|
||||
Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning
|
||||
parameters.
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
pretrained_init_configuration = QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION
|
||||
|
||||
|
||||
DPRSpanPrediction = collections.namedtuple(
|
||||
"DPRSpanPrediction", ["span_score", "relevance_score", "doc_id", "start_index", "end_index", "text"]
|
||||
)
|
||||
|
||||
DPRReaderOutput = collections.namedtuple("DPRReaderOutput", ["start_logits", "end_logits", "relevance_logits"])
|
||||
|
||||
|
||||
CUSTOM_DPR_READER_DOCSTRING = r"""
|
||||
Return a dictionary with the token ids of the input strings and other information to give to :obj:`.decode_best_spans`.
|
||||
It converts the strings of a question and different passages (title + text) into a sequence of integer ids, using the tokenizer and vocabulary.
|
||||
The resulting `input_ids` is a matrix of size :obj:`(n_passages, sequence_length)` with the format:
|
||||
|
||||
[CLS] <question token ids> [SEP] <titles ids> [SEP] <texts ids>
|
||||
|
||||
Inputs:
|
||||
questions (:obj:`str`, :obj:`List[str]`):
|
||||
The questions to be encoded.
|
||||
You can specify one question for many passages. In this case, the question will be duplicated like :obj:`[questions] * n_passages`.
|
||||
Otherwise you have to specify as many questions as in :obj:`titles` or :obj:`texts`.
|
||||
titles (:obj:`str`, :obj:`List[str]`):
|
||||
The passage titles to be encoded. This can be a string, or a list of strings if there are several passages.
|
||||
texts (:obj:`str`, :obj:`List[str]`):
|
||||
The passage texts to be encoded. This can be a string, or a list of strings if there are several passages.
|
||||
padding (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`True`):
|
||||
Activate and control padding. Accepts the following values:
|
||||
|
||||
* `True` or `'longest'`: pad to the longest sequence in the batch (or no padding if only a single sequence is provided),
|
||||
* `'max_length'`: pad to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`)
|
||||
* `False` or `'do_not_pad'`: No padding (i.e. can output a batch with sequences of uneven lengths)
|
||||
truncation (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`True`):
|
||||
Activate and control truncation. Accepts the following values:
|
||||
|
||||
* `True` or `'only_first'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`).
|
||||
* `False` or `'do_not_truncate'`: No truncation (i.e. can output a batch with sequence lengths greater than the model max admissible input size)
|
||||
max_length (:obj:`Union[int, None]`, `optional`, defaults to :obj:`None`):
|
||||
Control the length for padding/truncation. Accepts the following values:
|
||||
|
||||
* `None` (default): This will use the predefined model max length if required by one of the truncation/padding parameters. If the model has no specific max input length (e.g. XLNet) truncation/padding to max length is deactivated.
|
||||
* `any integer value` (e.g. `42`): Use this specific maximum length value if required by one of the truncation/padding parameters.
|
||||
return_tensors (:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||
Can be set to 'tf', 'pt' or 'np' to return respectively TensorFlow :obj:`tf.constant`,
|
||||
PyTorch :obj:`torch.Tensor` or NumPy :obj:`np.ndarray` instead of a list of python integers.
|
||||
return_attention_mask (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||
Whether to return the attention mask. If left to the default, will return the attention mask according
|
||||
to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
|
||||
Return:
|
||||
A Dictionary of shape::
|
||||
|
||||
{
|
||||
input_ids: list[list[int]],
|
||||
attention_mask: list[int] if return_attention_mask is True (default)
|
||||
}
|
||||
|
||||
With the fields:
|
||||
|
||||
- ``input_ids``: list of token ids to be fed to a model
|
||||
- ``attention_mask``: list of indices specifying which tokens should be attended to by the model
|
||||
|
||||
"""
|
||||
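The input layout described in the docstring above, `[CLS] <question token ids> [SEP] <titles ids> [SEP] <texts ids>`, can be sketched with plain lists. This is an illustration under stated assumptions, not the tokenizer's actual code path: the special-token ids, the helper name `build_reader_inputs`, and the pre-tokenized toy ids are all hypothetical, and padding always goes to the longest row.

```python
from typing import Dict, List, Optional

# Hypothetical special-token ids, chosen only for this illustration.
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0


def build_reader_inputs(
    question_ids: List[int],
    title_ids: List[List[int]],
    text_ids: List[List[int]],
    max_length: Optional[int] = None,
) -> Dict[str, List[List[int]]]:
    """Combine one question with several (title, text) passages."""
    rows = []
    for title, text in zip(title_ids, text_ids):
        # One row per passage: [CLS] question [SEP] title [SEP] text
        row = [CLS_ID] + question_ids + [SEP_ID] + title + [SEP_ID] + text
        if max_length is not None:
            row = row[:max_length]  # truncate from the right
        rows.append(row)
    longest = max(len(row) for row in rows)
    # Pad every row to the longest one, then mark real tokens with 1.
    input_ids = [row + [PAD_ID] * (longest - len(row)) for row in rows]
    attention_mask = [[int(tok != PAD_ID) for tok in row] for row in input_ids]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


encoded = build_reader_inputs(
    question_ids=[7, 8],
    title_ids=[[20], [21, 22]],
    text_ids=[[30], [32, 33]],
    max_length=8,
)
print(encoded["input_ids"])
print(encoded["attention_mask"])
```

The single question is repeated for every passage, which mirrors the `[questions] * n_passages` behaviour documented for `questions`.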
|
||||
|
||||
@add_start_docstrings(CUSTOM_DPR_READER_DOCSTRING)
|
||||
class CustomDPRReaderTokenizerMixin:
|
||||
def __call__(
|
||||
self,
|
||||
questions,
|
||||
titles,
|
||||
texts,
|
||||
padding: Union[bool, str] = True,
|
||||
truncation: Union[bool, str] = True,
|
||||
max_length: Optional[int] = 512,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
**kwargs
|
||||
) -> BatchEncoding:
|
||||
titles = titles if not isinstance(titles, str) else [titles]
|
||||
texts = texts if not isinstance(texts, str) else [texts]
|
||||
n_passages = len(titles)
|
||||
questions = questions if not isinstance(questions, str) else [questions] * n_passages
|
||||
assert len(titles) == len(
|
||||
texts
|
||||
), "There should be as many titles as texts but got {} titles and {} texts.".format(len(titles), len(texts))
|
||||
encoded_question_and_titles = super().__call__(questions, titles, padding=False, truncation=False)["input_ids"]
|
||||
encoded_texts = super().__call__(texts, add_special_tokens=False, padding=False, truncation=False)["input_ids"]
|
||||
encoded_inputs = {
|
||||
"input_ids": [
|
||||
(encoded_question_and_title + encoded_text)[:max_length]
|
||||
if max_length is not None and truncation
|
||||
else encoded_question_and_title + encoded_text
|
||||
for encoded_question_and_title, encoded_text in zip(encoded_question_and_titles, encoded_texts)
|
||||
]
|
||||
}
|
||||
if return_attention_mask is not False:
|
||||
attention_mask = [[int(input_id != self.pad_token_id) for input_id in input_ids] for input_ids in encoded_inputs["input_ids"]]
|
||||
encoded_inputs["attention_mask"] = attention_mask
|
||||
return self.pad(encoded_inputs, padding=padding, max_length=max_length, return_tensors=return_tensors)
|
||||
|
||||
def decode_best_spans(
|
||||
self,
|
||||
reader_input: BatchEncoding,
|
||||
reader_output: DPRReaderOutput,
|
||||
num_spans: int = 16,
|
||||
max_answer_length: int = 64,
|
||||
num_spans_per_passage: int = 4,
|
||||
) -> List[DPRSpanPrediction]:
|
||||
"""
|
||||
Get the span predictions for the extractive Q&A model.
|
||||
Outputs: `List` of `DPRSpanPrediction` sorted by descending `(relevance_score, span_score)`.
|
||||
Each `DPRSpanPrediction` is a named tuple with:
|
||||
**span_score**: ``float`` that corresponds to the score given by the reader for this span compared to other spans
|
||||
in the same passage. It corresponds to the sum of the start and end logits of the span.
|
||||
**relevance_score**: ``float`` that corresponds to the score of each passage to answer the question,
|
||||
compared to all the other passages. It corresponds to the output of the QA classifier of the DPRReader.
|
||||
**doc_id**: ``int`` the id of the passage.
|
||||
**start_index**: ``int`` the start index of the span (inclusive).
|
||||
**end_index**: ``int`` the end index of the span (inclusive).
|
||||
|
||||
Examples::
|
||||
|
||||
from transformers import DPRReader, DPRReaderTokenizer
|
||||
tokenizer = DPRReaderTokenizer.from_pretrained('facebook/dpr-reader-single-nq-base')
|
||||
model = DPRReader.from_pretrained('facebook/dpr-reader-single-nq-base')
|
||||
encoded_inputs = tokenizer(
|
||||
questions=["What is love ?"],
|
||||
titles=["Haddaway"],
|
||||
texts=["'What Is Love' is a song recorded by the artist Haddaway"],
|
||||
return_tensors='pt'
|
||||
)
|
||||
outputs = model(**encoded_inputs)
|
||||
predicted_spans = tokenizer.decode_best_spans(encoded_inputs, outputs)
|
||||
print(predicted_spans[0].text) # best span
|
||||
|
||||
"""
|
||||
input_ids = reader_input["input_ids"]
|
||||
start_logits, end_logits, relevance_logits = reader_output[:3]
|
||||
n_passages = len(relevance_logits)
|
||||
sorted_docs = sorted(range(n_passages), reverse=True, key=relevance_logits.__getitem__)
|
||||
nbest_spans_predictions: List[DPRSpanPrediction] = []
|
||||
for doc_id in sorted_docs:
|
||||
sequence_ids = list(input_ids[doc_id])
|
||||
# assuming question & title information is at the beginning of the sequence
|
||||
passage_offset = sequence_ids.index(self.sep_token_id, 2) + 1 # second sep id
|
||||
if sequence_ids[-1] == self.pad_token_id:
|
||||
sequence_len = sequence_ids.index(self.pad_token_id)
|
||||
else:
|
||||
sequence_len = len(sequence_ids)
|
||||
|
||||
best_spans = self._get_best_spans(
|
||||
start_logits=start_logits[doc_id][passage_offset:sequence_len],
|
||||
end_logits=end_logits[doc_id][passage_offset:sequence_len],
|
||||
max_answer_length=max_answer_length,
|
||||
top_spans=num_spans_per_passage,
|
||||
)
|
||||
for start_index, end_index in best_spans:
|
||||
start_index += passage_offset
|
||||
end_index += passage_offset
|
||||
nbest_spans_predictions.append(
|
||||
DPRSpanPrediction(
|
||||
span_score=start_logits[doc_id][start_index] + end_logits[doc_id][end_index],
|
||||
relevance_score=relevance_logits[doc_id],
|
||||
doc_id=doc_id,
|
||||
start_index=start_index,
|
||||
end_index=end_index,
|
||||
text=self.decode(sequence_ids[start_index : end_index + 1]),
|
||||
)
|
||||
)
|
||||
if len(nbest_spans_predictions) >= num_spans:
|
||||
break
|
||||
return nbest_spans_predictions[:num_spans]
|
||||
|
||||
def _get_best_spans(
|
||||
self, start_logits: List[int], end_logits: List[int], max_answer_length: int, top_spans: int,
|
||||
) -> List[DPRSpanPrediction]:
|
||||
"""
|
||||
Finds the best answer span for the extractive Q&A model for one passage.
|
||||
It returns the best spans in descending `span_score` order, keeping at most `top_spans` spans.
|
||||
Spans longer than `max_answer_length` are ignored.
|
||||
"""
|
||||
scores = []
|
||||
for (start_index, start_score) in enumerate(start_logits):
|
||||
for (answer_length, end_score) in enumerate(end_logits[start_index : start_index + max_answer_length]):
|
||||
scores.append(((start_index, start_index + answer_length), start_score + end_score))
|
||||
scores = sorted(scores, key=lambda x: x[1], reverse=True)
|
||||
chosen_span_intervals = []
|
||||
for (start_index, end_index), score in scores:
|
||||
assert start_index <= end_index, "Wrong span indices: [{}:{}]".format(start_index, end_index)
|
||||
length = end_index - start_index + 1
|
||||
assert length <= max_answer_length, "Span is too long: {} > {}".format(length, max_answer_length)
|
||||
if any(
|
||||
[
|
||||
start_index <= prev_start_index <= prev_end_index <= end_index
|
||||
or prev_start_index <= start_index <= end_index <= prev_end_index
|
||||
for (prev_start_index, prev_end_index) in chosen_span_intervals
|
||||
]
|
||||
):
|
||||
continue
|
||||
chosen_span_intervals.append((start_index, end_index))
|
||||
|
||||
if len(chosen_span_intervals) == top_spans:
|
||||
break
|
||||
return chosen_span_intervals
|
||||
|
||||
|
||||
@add_end_docstrings(CUSTOM_DPR_READER_DOCSTRING)
|
||||
class DPRReaderTokenizer(CustomDPRReaderTokenizerMixin, BertTokenizer):
|
||||
r"""
|
||||
Constructs a DPRReaderTokenizer.
|
||||
|
||||
:class:`~transformers.DPRReaderTokenizer` is almost identical to :class:`~transformers.BertTokenizer` and runs end-to-end
|
||||
tokenization: punctuation splitting + wordpiece.
|
||||
|
||||
The difference is that it takes three input strings: questions, titles and texts, which are combined and fed into the DPRReader model.
|
||||
|
||||
Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning
|
||||
parameters.
|
||||
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = READER_PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = READER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
pretrained_init_configuration = READER_PRETRAINED_INIT_CONFIGURATION
|
||||
model_input_names = ["attention_mask"]
|
||||
|
||||
|
||||
@add_end_docstrings(CUSTOM_DPR_READER_DOCSTRING)
|
||||
class DPRReaderTokenizerFast(CustomDPRReaderTokenizerMixin, BertTokenizerFast):
|
||||
r"""
|
||||
Constructs a DPRReaderTokenizerFast.
|
||||
|
||||
:class:`~transformers.DPRReaderTokenizerFast` is almost identical to :class:`~transformers.BertTokenizerFast` and runs end-to-end
|
||||
tokenization: punctuation splitting + wordpiece.
|
||||
|
||||
The difference is that it takes three input strings: questions, titles and texts, which are combined and fed into the DPRReader model.
|
||||
|
||||
Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning
|
||||
parameters.
|
||||
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = READER_PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = READER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
pretrained_init_configuration = READER_PRETRAINED_INIT_CONFIGURATION
|
||||
model_input_names = ["attention_mask"]
|
||||
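The span selection performed by `_get_best_spans` above can be read in isolation as a small standalone function. The sketch below restates the same idea under the assumption that logits are plain Python floats: enumerate every span of at most `max_answer_length` tokens, sort by summed start and end logits, and skip any span that contains or is contained in an already chosen one. The function name `best_spans` and the toy logits are hypothetical.

```python
from typing import List, Tuple


def best_spans(
    start_logits: List[float],
    end_logits: List[float],
    max_answer_length: int,
    top_spans: int,
) -> List[Tuple[int, int]]:
    """Return up to `top_spans` (start, end) pairs, both indices inclusive."""
    candidates = []
    for start, start_score in enumerate(start_logits):
        # Only consider end positions within max_answer_length of the start.
        window = end_logits[start : start + max_answer_length]
        for length, end_score in enumerate(window):
            candidates.append(((start, start + length), start_score + end_score))
    candidates.sort(key=lambda item: item[1], reverse=True)

    chosen: List[Tuple[int, int]] = []
    for (start, end), _score in candidates:
        # Skip spans nested inside, or enclosing, a previously chosen span.
        nested = any(
            (start <= s and e <= end) or (s <= start and end <= e)
            for s, e in chosen
        )
        if nested:
            continue
        chosen.append((start, end))
        if len(chosen) == top_spans:
            break
    return chosen


spans = best_spans(
    start_logits=[0.1, 2.0, 0.3, 1.0],
    end_logits=[0.2, 0.5, 3.0, 0.4],
    max_answer_length=2,
    top_spans=2,
)
print(spans)
```

Note that only nested spans are filtered out; two spans that partially overlap can both be kept.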
@@ -965,7 +965,7 @@ ENCODE_KWARGS_DOCSTRING = r"""
|
||||
>= 7.5 (Volta).
|
||||
return_tensors (:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||
Can be set to 'tf', 'pt' or 'np' to return respectively TensorFlow :obj:`tf.constant`,
|
||||
PyTorch :obj:`torch.Tensor` or Numpy :oj: `np.ndarray` instead of a list of python integers.
|
||||
PyTorch :obj:`torch.Tensor` or Numpy :obj: `np.ndarray` instead of a list of python integers.
|
||||
"""
|
||||
|
||||
ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
|
||||
@@ -1900,7 +1900,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
|
||||
return_tensors (:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||
Can be set to 'tf', 'pt' or 'np' to return respectively TensorFlow :obj:`tf.constant`,
|
||||
PyTorch :obj:`torch.Tensor` or Numpy :oj: `np.ndarray` instead of a list of python integers.
|
||||
PyTorch :obj:`torch.Tensor` or Numpy :obj: `np.ndarray` instead of a list of python integers.
|
||||
verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Set to ``False`` to avoid printing infos and warnings.
|
||||
"""
|
||||
|
||||
233
tests/test_modeling_dpr.py
Normal file
@@ -0,0 +1,233 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 Huggingface
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
from .utils import require_torch, slow, torch_device
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import BertConfig, DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader
|
||||
from transformers.modeling_dpr import (
|
||||
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
)
|
||||
|
||||
|
||||
class DPRModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        scope=None,
        projection_dim=0,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope
        self.projection_dim = projection_dim

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = BertConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
        )
        config = DPRConfig(projection_dim=self.projection_dim, **config.to_dict())

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    def create_and_check_dpr_context_encoder(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = DPRContextEncoder(config=config)
        model.to(torch_device)
        model.eval()
        # Exercise the forward pass with progressively fewer optional inputs;
        # only the result of the last call is checked below.
        embeddings = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)[0]
        embeddings = model(input_ids, token_type_ids=token_type_ids)[0]
        embeddings = model(input_ids)[0]

        result = {
            "embeddings": embeddings,
        }
        self.parent.assertListEqual(
            list(result["embeddings"].size()), [self.batch_size, self.projection_dim or self.hidden_size]
        )

    def create_and_check_dpr_question_encoder(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = DPRQuestionEncoder(config=config)
        model.to(torch_device)
        model.eval()
        # Exercise the forward pass with progressively fewer optional inputs;
        # only the result of the last call is checked below.
        embeddings = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)[0]
        embeddings = model(input_ids, token_type_ids=token_type_ids)[0]
        embeddings = model(input_ids)[0]

        result = {
            "embeddings": embeddings,
        }
        self.parent.assertListEqual(
            list(result["embeddings"].size()), [self.batch_size, self.projection_dim or self.hidden_size]
        )

    def create_and_check_dpr_reader(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = DPRReader(config=config)
        model.to(torch_device)
        model.eval()
        start_logits, end_logits, relevance_logits, *_ = model(input_ids, attention_mask=input_mask,)
        result = {
            "relevance_logits": relevance_logits,
            "start_logits": start_logits,
            "end_logits": end_logits,
        }
        self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
        self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
        self.parent.assertListEqual(list(result["relevance_logits"].size()), [self.batch_size])

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids}
        return config, inputs_dict


@require_torch
class DPRModelTest(ModelTesterMixin, unittest.TestCase):

    all_model_classes = (DPRContextEncoder, DPRQuestionEncoder, DPRReader,) if is_torch_available() else ()

    test_resize_embeddings = False
    test_missing_keys = False  # why?
    test_pruning = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = DPRModelTester(self)
        self.config_tester = ConfigTester(self, config_class=DPRConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_dpr_context_encoder_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_dpr_context_encoder(*config_and_inputs)

    def test_dpr_question_encoder_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_dpr_question_encoder(*config_and_inputs)

    def test_dpr_reader_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_dpr_reader(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        for model_name in DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = DPRContextEncoder.from_pretrained(model_name)
            self.assertIsNotNone(model)

        for model_name in DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = DPRQuestionEncoder.from_pretrained(model_name)
            self.assertIsNotNone(model)

        for model_name in DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = DPRReader.from_pretrained(model_name)
            self.assertIsNotNone(model)
89
tests/test_tokenization_dpr.py
Normal file
@@ -0,0 +1,89 @@
# coding=utf-8
# Copyright 2020 Huggingface
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from transformers.tokenization_dpr import (
    DPRContextEncoderTokenizer,
    DPRContextEncoderTokenizerFast,
    DPRQuestionEncoderTokenizer,
    DPRQuestionEncoderTokenizerFast,
    DPRReaderOutput,
    DPRReaderTokenizer,
    DPRReaderTokenizerFast,
)
from transformers.tokenization_utils_base import BatchEncoding

from .test_tokenization_bert import BertTokenizationTest
from .utils import slow


class DPRContextEncoderTokenizationTest(BertTokenizationTest):

    tokenizer_class = DPRContextEncoderTokenizer

    def get_rust_tokenizer(self, **kwargs):
        return DPRContextEncoderTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)


class DPRQuestionEncoderTokenizationTest(BertTokenizationTest):

    tokenizer_class = DPRQuestionEncoderTokenizer

    def get_rust_tokenizer(self, **kwargs):
        return DPRQuestionEncoderTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)


class DPRReaderTokenizationTest(BertTokenizationTest):

    tokenizer_class = DPRReaderTokenizer

    def get_rust_tokenizer(self, **kwargs):
        return DPRReaderTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)

    @slow
    def test_decode_best_spans(self):
        tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased")

        # Build one reader input by hand: [CLS] question [SEP] title [SEP] text
        text_1 = tokenizer.encode("question sequence", add_special_tokens=False)
        text_2 = tokenizer.encode("title sequence", add_special_tokens=False)
        text_3 = tokenizer.encode("text sequence " * 4, add_special_tokens=False)
        input_ids = [[101] + text_1 + [102] + text_2 + [102] + text_3]
        reader_input = BatchEncoding({"input_ids": input_ids})

        start_logits = [[0] * len(input_ids[0])]
        end_logits = [[0] * len(input_ids[0])]
        relevance_logits = [0]
        reader_output = DPRReaderOutput(start_logits, end_logits, relevance_logits)

        # Put a clear peak at the expected span boundaries. The lists are edited in place,
        # which assumes DPRReaderOutput keeps references to them rather than copies.
        start_index, end_index = 8, 9
        start_logits[0][start_index] = 10
        end_logits[0][end_index] = 10
        predicted_spans = tokenizer.decode_best_spans(reader_input, reader_output)
        self.assertEqual(predicted_spans[0].start_index, start_index)
        self.assertEqual(predicted_spans[0].end_index, end_index)
        self.assertEqual(predicted_spans[0].doc_id, 0)

    @slow
    def test_call(self):
        tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased")

        text_1 = tokenizer.encode("question sequence", add_special_tokens=False)
        text_2 = tokenizer.encode("title sequence", add_special_tokens=False)
        text_3 = tokenizer.encode("text sequence", add_special_tokens=False)
        expected_input_ids = [101] + text_1 + [102] + text_2 + [102] + text_3
        encoded_input = tokenizer(questions=["question sequence"], titles=["title sequence"], texts=["text sequence"])
        self.assertIn("input_ids", encoded_input)
        self.assertIn("attention_mask", encoded_input)
        self.assertListEqual(encoded_input["input_ids"][0], expected_input_ids)