update tokenizer
This commit is contained in:
@@ -16,21 +16,13 @@
|
|||||||
from __future__ import (absolute_import, division, print_function,
|
from __future__ import (absolute_import, division, print_function,
|
||||||
unicode_literals)
|
unicode_literals)
|
||||||
|
|
||||||
import sys
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import regex as re
|
import regex as re
|
||||||
from io import open
|
from io import open
|
||||||
import pdb
|
|
||||||
|
|
||||||
try:
|
from .tokenization_bert import BasicTokenizer
|
||||||
from functools import lru_cache
|
|
||||||
except ImportError:
|
|
||||||
# Just a dummy decorator to get the checks to run on python2
|
|
||||||
# because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
|
|
||||||
def lru_cache():
|
|
||||||
return lambda func: func
|
|
||||||
|
|
||||||
from .tokenization_utils import PreTrainedTokenizer
|
from .tokenization_utils import PreTrainedTokenizer
|
||||||
|
|
||||||
@@ -53,49 +45,47 @@ PRETRAINED_VOCAB_FILES_MAP = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||||
'ctrl': 1280,
|
'ctrl': 256,
|
||||||
}
|
}
|
||||||
|
|
||||||
@lru_cache()
|
def text_standardize(text):
|
||||||
def bytes_to_unicode():
|
|
||||||
"""
|
"""
|
||||||
Returns list of utf-8 byte and a mapping to unicode strings.
|
fixes some issues the spacy tokenizer had on books corpus
|
||||||
We specifically avoids mapping to whitespace/control characters the bpe code barfs on.
|
also does some whitespace standardization
|
||||||
|
"""
|
||||||
|
text = text.replace('—', '-')
|
||||||
|
text = text.replace('–', '-')
|
||||||
|
text = text.replace('―', '-')
|
||||||
|
text = text.replace('…', '...')
|
||||||
|
text = text.replace('´', "'")
|
||||||
|
text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text)
|
||||||
|
text = re.sub(r'\s*\n\s*', ' \n ', text)
|
||||||
|
text = re.sub(r'[^\S\n]+', ' ', text)
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
The reversible bpe codes work on unicode strings.
|
|
||||||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
|
||||||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
|
||||||
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
|
||||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
|
||||||
"""
|
|
||||||
_chr = unichr if sys.version_info[0] == 2 else chr
|
|
||||||
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
|
||||||
cs = bs[:]
|
|
||||||
n = 0
|
|
||||||
for b in range(2**8):
|
|
||||||
if b not in bs:
|
|
||||||
bs.append(b)
|
|
||||||
cs.append(2**8+n)
|
|
||||||
n += 1
|
|
||||||
cs = [_chr(n) for n in cs]
|
|
||||||
return dict(zip(bs, cs))
|
|
||||||
|
|
||||||
def get_pairs(word):
|
def get_pairs(word):
|
||||||
"""Return set of symbol pairs in a word.
|
"""Return set of symbol pairs in a word.
|
||||||
|
|
||||||
Word is represented as tuple of symbols (symbols being variable-length strings).
|
Word is represented as tuple of symbols (symbols being variable-length strings).
|
||||||
"""
|
"""
|
||||||
pairs= []
|
# pairs = []
|
||||||
|
# prev_char = word[0]
|
||||||
|
# for i, char in enumerate(word[1:]):
|
||||||
|
# #_i = i + 1
|
||||||
|
# #if word[_i+1:] == tuple('</w>'):
|
||||||
|
# # pairs.append((prev_char, char+'</w>'))
|
||||||
|
# # break
|
||||||
|
# #else:
|
||||||
|
# if True:
|
||||||
|
# pairs.append((prev_char, char))
|
||||||
|
# prev_char = char
|
||||||
|
|
||||||
|
pairs = set()
|
||||||
prev_char = word[0]
|
prev_char = word[0]
|
||||||
for i, char in enumerate(word[1:]):
|
for char in word[1:]:
|
||||||
#_i = i + 1
|
pairs.add((prev_char, char))
|
||||||
#if word[_i+1:] == tuple('</w>'):
|
prev_char = char
|
||||||
# pairs.append((prev_char, char+'</w>'))
|
|
||||||
# break
|
|
||||||
#else:
|
|
||||||
if True:
|
|
||||||
pairs.append((prev_char, char))
|
|
||||||
prev_char = char
|
|
||||||
|
|
||||||
pairs = set(pairs)
|
pairs = set(pairs)
|
||||||
return pairs
|
return pairs
|
||||||
@@ -113,24 +103,28 @@ class CTRLTokenizer(PreTrainedTokenizer):
|
|||||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||||
|
|
||||||
def __init__(self, vocab_file, merges_file, errors='replace', unk_token="<unk>",
|
def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
|
||||||
bos_token="<|endoftext|>", eos_token="<|endoftext|>", **kwargs):
|
super(CTRLTokenizer, self).__init__(unk_token=unk_token, **kwargs)
|
||||||
super(CTRLTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
|
|
||||||
self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens
|
self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens
|
||||||
self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens
|
self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens
|
||||||
|
|
||||||
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
|
try:
|
||||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
import ftfy
|
||||||
self.errors = errors # how to handle errors in decoding
|
from spacy.lang.en import English
|
||||||
self.byte_encoder = bytes_to_unicode()
|
_nlp = English()
|
||||||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
self.nlp = _nlp.Defaults.create_tokenizer(_nlp)
|
||||||
bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
|
self.fix_text = ftfy.fix_text
|
||||||
bpe_merges = [tuple(merge.split()) for merge in bpe_data]
|
except ImportError:
|
||||||
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
|
logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
|
||||||
self.cache = {}
|
self.nlp = BasicTokenizer(do_lower_case=True)
|
||||||
|
self.fix_text = None
|
||||||
|
|
||||||
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
|
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
|
||||||
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
|
self.decoder = {v:k for k,v in self.encoder.items()}
|
||||||
|
merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
|
||||||
|
merges = [tuple(merge.split()) for merge in merges]
|
||||||
|
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||||
|
self.cache = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def vocab_size(self):
|
def vocab_size(self):
|
||||||
@@ -179,23 +173,27 @@ class CTRLTokenizer(PreTrainedTokenizer):
|
|||||||
self.cache[token] = word
|
self.cache[token] = word
|
||||||
return word
|
return word
|
||||||
|
|
||||||
def _tokenize(self, text, add_prefix_space=False):
|
def _tokenize(self, text):
|
||||||
""" Tokenize a string.
|
""" Tokenize a string.
|
||||||
Args:
|
|
||||||
- add_prefix_space (boolean, default False):
|
|
||||||
Begin the sentence with at least one space toto get invariance to word order in CTRL (and RoBERTa) tokenizers.
|
|
||||||
"""
|
"""
|
||||||
if add_prefix_space:
|
split_tokens = []
|
||||||
text = ' ' + text
|
if self.fix_text is None:
|
||||||
|
# Using BERT's BasicTokenizer
|
||||||
bpe_tokens = []
|
text = self.nlp.tokenize(text)
|
||||||
for token in text.split():
|
for token in text:
|
||||||
if sys.version_info[0] == 2:
|
split_tokens.extend([t for t in self.bpe(token).split(' ')])
|
||||||
token = ''.join(self.byte_encoder[ord(b)] for b in token) # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case)
|
else:
|
||||||
else:
|
# Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
|
||||||
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case)
|
text = self.nlp(text_standardize(self.fix_text(text)))
|
||||||
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
|
for token in text:
|
||||||
return bpe_tokens
|
split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
|
||||||
|
# for token in text.split():
|
||||||
|
# if sys.version_info[0] == 2:
|
||||||
|
# token = ''.join(self.byte_encoder[ord(b)] for b in token) # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case)
|
||||||
|
# else:
|
||||||
|
# token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case)
|
||||||
|
# bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
|
||||||
|
return split_tokens
|
||||||
|
|
||||||
def _convert_token_to_id(self, token):
|
def _convert_token_to_id(self, token):
|
||||||
""" Converts a token (str/unicode) in an id using the vocab. """
|
""" Converts a token (str/unicode) in an id using the vocab. """
|
||||||
@@ -203,13 +201,12 @@ class CTRLTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
def _convert_id_to_token(self, index):
|
def _convert_id_to_token(self, index):
|
||||||
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
|
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
|
||||||
return self.decoder.get(index)
|
return self.decoder.get(index, self.unk_token)
|
||||||
|
|
||||||
def convert_tokens_to_string(self, tokens):
|
def convert_tokens_to_string(self, tokens):
|
||||||
""" Converts a sequence of tokens (string) in a single string. """
|
""" Converts a sequence of tokens (string) in a single string. """
|
||||||
text = ''.join(tokens)
|
out_string = ''.join(tokens).replace('@@', ' ').strip()
|
||||||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
|
return out_string
|
||||||
return text
|
|
||||||
|
|
||||||
def save_vocabulary(self, save_directory):
|
def save_vocabulary(self, save_directory):
|
||||||
"""Save the tokenizer vocabulary and merge files to a directory."""
|
"""Save the tokenizer vocabulary and merge files to a directory."""
|
||||||
@@ -235,10 +232,8 @@ class CTRLTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
return vocab_file, merge_file
|
return vocab_file, merge_file
|
||||||
|
|
||||||
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
|
# def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
|
||||||
filtered_tokens = ' '.join(self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens))
|
# filtered_tokens = ' '.join(self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens))
|
||||||
tokens_generated_so_far = re.sub('(@@ )', '', string=filtered_tokens)
|
# tokens_generated_so_far = re.sub('(@@ )', '', string=filtered_tokens)
|
||||||
tokens_generated_so_far = re.sub('(@@ ?$)', '', string=tokens_generated_so_far)
|
# tokens_generated_so_far = re.sub('(@@ ?$)', '', string=tokens_generated_so_far)
|
||||||
return ''.join(tokens_generated_so_far)
|
# return ''.join(tokens_generated_so_far)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user