Warn instead of raising in BERT and GPT-2 tokenizers as well, to allow for pre-caching of tokens

This commit is contained in:
Catalin Voss
2019-03-05 12:31:45 -08:00
parent e99bc87e4d
commit 4a49c22584
2 changed files with 2 additions and 2 deletions

View File

@@ -101,7 +101,7 @@ class BertTokenizer(object):
for token in tokens: for token in tokens:
ids.append(self.vocab[token]) ids.append(self.vocab[token])
if len(ids) > self.max_len: if len(ids) > self.max_len:
raise ValueError( logger.warning(
"Token indices sequence length is longer than the specified maximum " "Token indices sequence length is longer than the specified maximum "
" sequence length for this BERT model ({} > {}). Running this" " sequence length for this BERT model ({} > {}). Running this"
" sequence through BERT will result in indexing errors".format(len(ids), self.max_len) " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)

View File

@@ -193,7 +193,7 @@ class GPT2Tokenizer(object):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
if len(bpe_tokens) > self.max_len: if len(bpe_tokens) > self.max_len:
raise ValueError( logger.warning(
"Token indices sequence length is longer than the specified maximum " "Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT-2 model ({} > {}). Running this" " sequence length for this OpenAI GPT-2 model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(len(bpe_tokens), self.max_len) " sequence through the model will result in indexing errors".format(len(bpe_tokens), self.max_len)