QA pipeline BART compatible (#5496)
* Ensure padding and question cannot have higher probs than context. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Add bart the the list of tokenizers adding two <sep> tokens for squad_convert_example_to_feature Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Format. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Addressing @patrickvonplaten comments. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Addressing @patrickvonplaten comments about masking non-context element when generating the answer. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Addressing @sshleifer comments. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Make sure we mask CLS after handling impossible answers Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Mask in the correct vectors ... Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
This commit is contained in:
@@ -12,6 +12,10 @@ from ...tokenization_bert import whitespace_tokenize
|
||||
from .utils import DataProcessor
|
||||
|
||||
|
||||
# Store the tokenizers which insert 2 separators tokens
|
||||
MULTI_SEP_TOKENS_TOKENIZERS_SET = {"roberta", "camembert", "bart"}
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch.utils.data import TensorDataset
|
||||
@@ -123,9 +127,13 @@ def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_q
|
||||
truncated_query = tokenizer.encode(
|
||||
example.question_text, add_special_tokens=False, truncation=True, max_length=max_query_length
|
||||
)
|
||||
|
||||
# Tokenizers who insert 2 SEP tokens in-between <context> & <question> need to have special handling
|
||||
# in the way they compute mask of added tokens.
|
||||
tokenizer_type = type(tokenizer).__name__.replace("Tokenizer", "").lower()
|
||||
sequence_added_tokens = (
|
||||
tokenizer.max_len - tokenizer.max_len_single_sentence + 1
|
||||
if "roberta" in str(type(tokenizer)) or "camembert" in str(type(tokenizer))
|
||||
if tokenizer_type in MULTI_SEP_TOKENS_TOKENIZERS_SET
|
||||
else tokenizer.max_len - tokenizer.max_len_single_sentence
|
||||
)
|
||||
sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
|
||||
|
||||
Reference in New Issue
Block a user