OpenAI GPT Tokenizer can fall back on using BERT BasicTokenizer
This commit is contained in:
@@ -26,6 +26,7 @@ from io import open
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from .file_utils import cached_path
|
from .file_utils import cached_path
|
||||||
|
from .tokenization import BasicTokenizer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -72,8 +73,9 @@ class OpenAIGPTTokenizer(object):
|
|||||||
"""
|
"""
|
||||||
BPE tokenizer. Peculiarities:
|
BPE tokenizer. Peculiarities:
|
||||||
- lower case all inputs
|
- lower case all inputs
|
||||||
- uses SpaCy tokenizer
|
- uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not.
|
||||||
- special tokens: additional symbols (ex: "__classify__") to add to a vocabulary.
|
- argument special_tokens and function set_special_tokens:
|
||||||
|
can be used to add additional symbols (ex: "__classify__") to a vocabulary.
|
||||||
"""
|
"""
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
|
||||||
@@ -122,12 +124,15 @@ class OpenAIGPTTokenizer(object):
|
|||||||
try:
|
try:
|
||||||
import ftfy
|
import ftfy
|
||||||
import spacy
|
import spacy
|
||||||
|
self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
|
||||||
|
self.fix_text = ftfy.fix_text
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("Please install ftfy and spacy to use OpenAI GPT tokenizer.")
|
logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
|
||||||
|
self.nlp = BasicTokenizer(do_lower_case=True,
|
||||||
|
never_split=special_tokens if special_tokens is not None else [])
|
||||||
|
self.fix_text = None
|
||||||
|
|
||||||
self.max_len = max_len if max_len is not None else int(1e12)
|
self.max_len = max_len if max_len is not None else int(1e12)
|
||||||
self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
|
|
||||||
self.fix_text = ftfy.fix_text
|
|
||||||
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
|
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
|
||||||
self.decoder = {v:k for k,v in self.encoder.items()}
|
self.decoder = {v:k for k,v in self.encoder.items()}
|
||||||
merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
|
merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
|
||||||
@@ -150,6 +155,9 @@ class OpenAIGPTTokenizer(object):
|
|||||||
return
|
return
|
||||||
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
|
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
|
||||||
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
|
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
|
||||||
|
if self.fix_text is None:
|
||||||
|
# Using BERT's BasicTokenizer: we can update the tokenizer
|
||||||
|
self.nlp.never_split = special_tokens
|
||||||
logger.info("Special tokens {}".format(self.special_tokens))
|
logger.info("Special tokens {}".format(self.special_tokens))
|
||||||
|
|
||||||
def bpe(self, token):
|
def bpe(self, token):
|
||||||
@@ -198,9 +206,16 @@ class OpenAIGPTTokenizer(object):
|
|||||||
def tokenize(self, text):
|
def tokenize(self, text):
|
||||||
""" Tokenize a string. """
|
""" Tokenize a string. """
|
||||||
split_tokens = []
|
split_tokens = []
|
||||||
text = self.nlp(text_standardize(self.fix_text(text)))
|
if self.fix_text is None:
|
||||||
for token in text:
|
# Using BERT's BasicTokenizer
|
||||||
split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
|
text = self.nlp.tokenize(text)
|
||||||
|
for token in text:
|
||||||
|
split_tokens.extend([t for t in self.bpe(token).split(' ')])
|
||||||
|
else:
|
||||||
|
# Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
|
||||||
|
text = self.nlp(text_standardize(self.fix_text(text)))
|
||||||
|
for token in text:
|
||||||
|
split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
|
||||||
return split_tokens
|
return split_tokens
|
||||||
|
|
||||||
def convert_tokens_to_ids(self, tokens):
|
def convert_tokens_to_ids(self, tokens):
|
||||||
@@ -219,8 +234,8 @@ class OpenAIGPTTokenizer(object):
|
|||||||
if len(ids) > self.max_len:
|
if len(ids) > self.max_len:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Token indices sequence length is longer than the specified maximum "
|
"Token indices sequence length is longer than the specified maximum "
|
||||||
" sequence length for this BERT model ({} > {}). Running this"
|
" sequence length for this OpenAI GPT model ({} > {}). Running this"
|
||||||
" sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
|
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
|
||||||
)
|
)
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user