QA pipeline BART compatible (#5496)

* Ensure padding and question cannot have higher probs than context.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Add bart the the list of tokenizers adding two <sep> tokens for squad_convert_example_to_feature

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Format.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Addressing @patrickvonplaten comments.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Addressing @patrickvonplaten comments about masking non-context element when generating the answer.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Addressing @sshleifer comments.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make sure we mask CLS after handling impossible answers

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Mask in the correct vectors ...

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
This commit is contained in:
Funtowicz Morgan
2020-07-09 15:11:40 +02:00
committed by GitHub
parent fa5423b169
commit 3bd55199cd
2 changed files with 20 additions and 8 deletions

View File

@@ -12,6 +12,10 @@ from ...tokenization_bert import whitespace_tokenize
from .utils import DataProcessor
# Store the tokenizers which insert 2 separators tokens
MULTI_SEP_TOKENS_TOKENIZERS_SET = {"roberta", "camembert", "bart"}
if is_torch_available():
import torch
from torch.utils.data import TensorDataset
@@ -123,9 +127,13 @@ def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_q
truncated_query = tokenizer.encode(
example.question_text, add_special_tokens=False, truncation=True, max_length=max_query_length
)
# Tokenizers who insert 2 SEP tokens in-between <context> & <question> need to have special handling
# in the way they compute mask of added tokens.
tokenizer_type = type(tokenizer).__name__.replace("Tokenizer", "").lower()
sequence_added_tokens = (
tokenizer.max_len - tokenizer.max_len_single_sentence + 1
if "roberta" in str(type(tokenizer)) or "camembert" in str(type(tokenizer))
if tokenizer_type in MULTI_SEP_TOKENS_TOKENIZERS_SET
else tokenizer.max_len - tokenizer.max_len_single_sentence
)
sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair