Merge pull request #127 from patrick-s-h-lewis/tokenizer-error-on-long-seqs
raises value error for bert tokenizer for long sequences
This commit is contained in:
@@ -36,6 +36,15 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = {
|
|||||||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
|
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
|
||||||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
|
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
|
||||||
}
|
}
|
||||||
|
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
|
||||||
|
'bert-base-uncased': 512,
|
||||||
|
'bert-large-uncased': 512,
|
||||||
|
'bert-base-cased': 512,
|
||||||
|
'bert-large-cased': 512,
|
||||||
|
'bert-base-multilingual-uncased': 512,
|
||||||
|
'bert-base-multilingual-cased': 512,
|
||||||
|
'bert-base-chinese': 512,
|
||||||
|
}
|
||||||
VOCAB_NAME = 'vocab.txt'
|
VOCAB_NAME = 'vocab.txt'
|
||||||
|
|
||||||
|
|
||||||
@@ -65,7 +74,8 @@ def whitespace_tokenize(text):
|
|||||||
|
|
||||||
class BertTokenizer(object):
|
class BertTokenizer(object):
|
||||||
"""Runs end-to-end tokenization: punctuation splitting + wordpiece"""
|
"""Runs end-to-end tokenization: punctuation splitting + wordpiece"""
|
||||||
def __init__(self, vocab_file, do_lower_case=True):
|
|
||||||
|
def __init__(self, vocab_file, do_lower_case=True, max_len=None):
|
||||||
if not os.path.isfile(vocab_file):
|
if not os.path.isfile(vocab_file):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
|
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
|
||||||
@@ -75,6 +85,7 @@ class BertTokenizer(object):
|
|||||||
[(ids, tok) for tok, ids in self.vocab.items()])
|
[(ids, tok) for tok, ids in self.vocab.items()])
|
||||||
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
|
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
|
||||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
||||||
|
self.max_len = max_len if max_len is not None else int(1e12)
|
||||||
|
|
||||||
def tokenize(self, text):
|
def tokenize(self, text):
|
||||||
split_tokens = []
|
split_tokens = []
|
||||||
@@ -88,6 +99,12 @@ class BertTokenizer(object):
|
|||||||
ids = []
|
ids = []
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
ids.append(self.vocab[token])
|
ids.append(self.vocab[token])
|
||||||
|
if len(ids) > self.max_len:
|
||||||
|
raise ValueError(
|
||||||
|
"Token indices sequence length is longer than the specified maximum "
|
||||||
|
" sequence length for this BERT model ({} > {}). Running this"
|
||||||
|
" sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
|
||||||
|
)
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
def convert_ids_to_tokens(self, ids):
|
def convert_ids_to_tokens(self, ids):
|
||||||
@@ -126,6 +143,11 @@ class BertTokenizer(object):
|
|||||||
else:
|
else:
|
||||||
logger.info("loading vocabulary file {} from cache at {}".format(
|
logger.info("loading vocabulary file {} from cache at {}".format(
|
||||||
vocab_file, resolved_vocab_file))
|
vocab_file, resolved_vocab_file))
|
||||||
|
if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
|
||||||
|
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
|
||||||
|
# than the number of positional embeddings
|
||||||
|
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name]
|
||||||
|
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
|
||||||
# Instantiate tokenizer.
|
# Instantiate tokenizer.
|
||||||
tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
|
tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|||||||
@@ -44,6 +44,24 @@ class TokenizationTest(unittest.TestCase):
|
|||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
||||||
|
|
||||||
|
def test_full_tokenizer_raises_error_for_long_sequences(self):
|
||||||
|
vocab_tokens = [
|
||||||
|
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||||
|
"##ing", ","
|
||||||
|
]
|
||||||
|
with open("/tmp/bert_tokenizer_test.txt", "w") as vocab_writer:
|
||||||
|
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||||
|
vocab_file = vocab_writer.name
|
||||||
|
|
||||||
|
tokenizer = BertTokenizer(vocab_file, max_len=10)
|
||||||
|
os.remove(vocab_file)
|
||||||
|
tokens = tokenizer.tokenize(u"the cat sat on the mat in the summer time")
|
||||||
|
indices = tokenizer.convert_tokens_to_ids(tokens)
|
||||||
|
self.assertListEqual(indices, [0 for _ in range(10)])
|
||||||
|
|
||||||
|
tokens = tokenizer.tokenize(u"the cat sat on the mat in the summer time .")
|
||||||
|
self.assertRaises(ValueError, tokenizer.convert_tokens_to_ids, tokens)
|
||||||
|
|
||||||
def test_chinese(self):
|
def test_chinese(self):
|
||||||
tokenizer = BasicTokenizer()
|
tokenizer = BasicTokenizer()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user