Fix tokenization in SQuAD for RoBERTa, Longformer, BART (#7387)
* fix squad tokenization for roberta & co * change to pure type based check * sort imports
This commit is contained in:
@@ -7,7 +7,10 @@ import numpy as np
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from ...file_utils import is_tf_available, is_torch_available
|
from ...file_utils import is_tf_available, is_torch_available
|
||||||
|
from ...tokenization_bart import BartTokenizer
|
||||||
from ...tokenization_bert import whitespace_tokenize
|
from ...tokenization_bert import whitespace_tokenize
|
||||||
|
from ...tokenization_longformer import LongformerTokenizer
|
||||||
|
from ...tokenization_roberta import RobertaTokenizer
|
||||||
from ...tokenization_utils_base import TruncationStrategy
|
from ...tokenization_utils_base import TruncationStrategy
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from .utils import DataProcessor
|
from .utils import DataProcessor
|
||||||
@@ -109,7 +112,10 @@ def squad_convert_example_to_features(
|
|||||||
all_doc_tokens = []
|
all_doc_tokens = []
|
||||||
for (i, token) in enumerate(example.doc_tokens):
|
for (i, token) in enumerate(example.doc_tokens):
|
||||||
orig_to_tok_index.append(len(all_doc_tokens))
|
orig_to_tok_index.append(len(all_doc_tokens))
|
||||||
sub_tokens = tokenizer.tokenize(token)
|
if isinstance(tokenizer, (RobertaTokenizer, LongformerTokenizer, BartTokenizer)):
|
||||||
|
sub_tokens = tokenizer.tokenize(token, add_prefix_space=True)
|
||||||
|
else:
|
||||||
|
sub_tokens = tokenizer.tokenize(token)
|
||||||
for sub_token in sub_tokens:
|
for sub_token in sub_tokens:
|
||||||
tok_to_orig_index.append(i)
|
tok_to_orig_index.append(i)
|
||||||
all_doc_tokens.append(sub_token)
|
all_doc_tokens.append(sub_token)
|
||||||
|
|||||||
Reference in New Issue
Block a user