From b450a7faf28cadcc572a4b55550337ab58a8e48c Mon Sep 17 00:00:00 2001
From: thomwolf
Date: Mon, 18 Feb 2019 11:27:18 +0100
Subject: [PATCH] clean up tokenization - fix python 2 tests

---
 pytorch_pretrained_bert/tokenization_gpt2.py | 22 +++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/pytorch_pretrained_bert/tokenization_gpt2.py b/pytorch_pretrained_bert/tokenization_gpt2.py
index 1f62d63200..f76651402d 100644
--- a/pytorch_pretrained_bert/tokenization_gpt2.py
+++ b/pytorch_pretrained_bert/tokenization_gpt2.py
@@ -20,14 +20,19 @@ import json
 import logging
 import os
 import regex as re
-import sys
 from io import open
-from functools import lru_cache
 
-from tqdm import tqdm
+try:
+    from functools import lru_cache
+except ImportError:
+    # Python 2 has no functools.lru_cache: use a no-op stand-in (no caching) so the checks still run there;
+    # a byte-level unicode BPE tokenizer is not supported on python 2 right now.
+    def lru_cache(maxsize=128, typed=False):  # same call signature as functools.lru_cache
+        if callable(maxsize):  # applied bare as ``@lru_cache``: ``maxsize`` is the decorated function
+            return maxsize
+        return lambda func: func  # applied as ``@lru_cache(...)``: hand back an identity decorator
 
 from .file_utils import cached_path
-from .tokenization import BasicTokenizer
 
 logger = logging.getLogger(__name__)
 
@@ -125,7 +130,8 @@ class GPT2Tokenizer(object):
         tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
         return tokenizer
 
-    def __init__(self, vocab_file, merges_file, errors='replace'):
+    def __init__(self, vocab_file, merges_file, errors='replace', max_len=None):
+        self.max_len = max_len if max_len is not None else int(1e12)
         self.encoder = json.load(open(vocab_file))
         self.decoder = {v:k for k,v in self.encoder.items()}
         self.errors = errors # how to handle errors in decoding
@@ -188,6 +194,12 @@ class GPT2Tokenizer(object):
         for token in re.findall(self.pat, text):
             token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
             bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+        if len(bpe_tokens) > self.max_len:
+            raise ValueError(
+                "Token indices sequence length is longer than the specified maximum "
+                " sequence length for this OpenAI GPT-2 model ({} > {}). Running this"
+                " sequence through the model will result in indexing errors".format(len(bpe_tokens), self.max_len)
+            )
         return bpe_tokens
 
     def decode(self, tokens):