From 78cf7b4ab4de783942383b008be7eb7f65dc541d Mon Sep 17 00:00:00 2001 From: Patrick Lewis Date: Tue, 18 Dec 2018 14:41:30 +0000 Subject: [PATCH] added code to raise ValueError for bert tokenizer in convert_tokens_to_ids when the sequence exceeds max_len --- pytorch_pretrained_bert/tokenization.py | 44 ++++++++++++++++++------- tests/tokenization_test.py | 22 +++++++++++-- 2 files changed, 53 insertions(+), 13 deletions(-) diff --git a/pytorch_pretrained_bert/tokenization.py b/pytorch_pretrained_bert/tokenization.py index 5954b78f68..838401565b 100644 --- a/pytorch_pretrained_bert/tokenization.py +++ b/pytorch_pretrained_bert/tokenization.py @@ -36,6 +36,15 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = { 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", } +PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { + 'bert-base-uncased': 512, + 'bert-large-uncased': 512, + 'bert-base-cased': 512, + 'bert-large-cased': 512, + 'bert-base-multilingual-uncased': 512, + 'bert-base-multilingual-cased': 512, + 'bert-base-chinese': 512, +} VOCAB_NAME = 'vocab.txt' @@ -65,7 +74,8 @@ def whitespace_tokenize(text): class BertTokenizer(object): """Runs end-to-end tokenization: punctuation splitting + wordpiece""" - def __init__(self, vocab_file, do_lower_case=True): + + def __init__(self, vocab_file, do_lower_case=True, max_len=None): if not os.path.isfile(vocab_file): raise ValueError( "Can't find a vocabulary file at path '{}'. 
To load the vocabulary from a Google pretrained " @@ -75,6 +85,7 @@ class BertTokenizer(object): [(ids, tok) for tok, ids in self.vocab.items()]) self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + self.max_len = max_len if max_len is not None else int(1e12) def tokenize(self, text): split_tokens = [] @@ -88,6 +99,12 @@ class BertTokenizer(object): ids = [] for token in tokens: ids.append(self.vocab[token]) + if len(ids) > self.max_len: + raise ValueError( + "Token indices sequence length is longer than the specified maximum " + " sequence length for this BERT model ({} > {}). Running this" + " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) + ) return ids def convert_ids_to_tokens(self, ids): @@ -126,6 +143,11 @@ class BertTokenizer(object): else: logger.info("loading vocabulary file {} from cache at {}".format( vocab_file, resolved_vocab_file)) + if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: + # if we're using a pretrained model, ensure the tokenizer wont index sequences longer + # than the number of positional embeddings + max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name] + kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) # Instantiate tokenizer. tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) return tokenizer @@ -193,7 +215,7 @@ class BasicTokenizer(object): i += 1 return ["".join(x) for x in output] - + def _tokenize_chinese_chars(self, text): """Adds whitespace around any CJK character.""" output = [] @@ -218,17 +240,17 @@ class BasicTokenizer(object): # space-separated words, so they are not treated specially and handled # like the all of the other languages. 
if ((cp >= 0x4E00 and cp <= 0x9FFF) or # - (cp >= 0x3400 and cp <= 0x4DBF) or # - (cp >= 0x20000 and cp <= 0x2A6DF) or # - (cp >= 0x2A700 and cp <= 0x2B73F) or # - (cp >= 0x2B740 and cp <= 0x2B81F) or # - (cp >= 0x2B820 and cp <= 0x2CEAF) or - (cp >= 0xF900 and cp <= 0xFAFF) or # - (cp >= 0x2F800 and cp <= 0x2FA1F)): # + (cp >= 0x3400 and cp <= 0x4DBF) or # + (cp >= 0x20000 and cp <= 0x2A6DF) or # + (cp >= 0x2A700 and cp <= 0x2B73F) or # + (cp >= 0x2B740 and cp <= 0x2B81F) or # + (cp >= 0x2B820 and cp <= 0x2CEAF) or + (cp >= 0xF900 and cp <= 0xFAFF) or # + (cp >= 0x2F800 and cp <= 0x2FA1F)): # return True - + return False - + def _clean_text(self, text): """Performs invalid character removal and whitespace cleanup on text.""" output = [] diff --git a/tests/tokenization_test.py b/tests/tokenization_test.py index f541a620e8..e1474e938b 100644 --- a/tests/tokenization_test.py +++ b/tests/tokenization_test.py @@ -44,12 +44,30 @@ class TokenizationTest(unittest.TestCase): self.assertListEqual( tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) + def test_full_tokenizer_raises_error_for_long_sequences(self): + vocab_tokens = [ + "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", + "##ing", "," + ] + with open("/tmp/bert_tokenizer_test.txt", "w") as vocab_writer: + vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) + vocab_file = vocab_writer.name + + tokenizer = BertTokenizer(vocab_file, max_len=10) + os.remove(vocab_file) + tokens = tokenizer.tokenize(u"the cat sat on the mat in the summer time") + indices = tokenizer.convert_tokens_to_ids(tokens) + self.assertListEqual(indices, [0 for _ in range(10)]) + + tokens = tokenizer.tokenize(u"the cat sat on the mat in the summer time .") + self.assertRaises(ValueError, tokenizer.convert_tokens_to_ids, tokens) + def test_chinese(self): tokenizer = BasicTokenizer() - + self.assertListEqual( tokenizer.tokenize(u"ah\u535A\u63A8zz"), - [u"ah", u"\u535A", u"\u63A8", u"zz"]) + [u"ah", 
u"\u535A", u"\u63A8", u"zz"]) def test_basic_tokenizer_lower(self): tokenizer = BasicTokenizer(do_lower_case=True)