adding conversion script adding first draft of modeling & tokenization adding placeholder for test files bunch of changes registering the tokenizer/model/etc tests change link; something is very VERY wrong here weird end-of-word thingy going on i think the tokenization works now ; wrote the unit tests overall structure works;load w next the monster is alive! works after some cleanup as well adding emacs autosave to gitignore currently only supporting the 48 layer one; seems to infer fine on my macbook cleanup fixing some documentation fixing some documentation tests passing? now works on CUDA also adding greedy? adding greedy sampling works well
245 lines
9.6 KiB
Python
245 lines
9.6 KiB
Python
# coding=utf-8
|
|
# Copyright 2018 Salesforce and The HuggingFace Inc. team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Tokenization classes for Salesforce CTRL."""
|
|
from __future__ import (absolute_import, division, print_function,
|
|
unicode_literals)
|
|
|
|
import sys
|
|
import json
|
|
import logging
|
|
import os
|
|
import regex as re
|
|
from io import open
|
|
import pdb
|
|
|
|
try:
    from functools import lru_cache
except ImportError:
    # Python 2 has no functools.lru_cache: substitute a factory whose decorator
    # returns the function unchanged, so ``@lru_cache()`` still applies (just
    # without memoization). A byte-level unicode BPE tokenizer is not really
    # supported on Python 2; this only keeps the module importable there.
    def lru_cache():
        def _no_cache(func):
            return func
        return _no_cache
|
|
|
|
from .tokenization_utils import PreTrainedTokenizer
|
|
|
|
logger = logging.getLogger(__name__)

# File names under which the vocabulary and the BPE merges are stored on disk.
VOCAB_FILES_NAMES = {
    'vocab_file': 'vocab.json',
    'merges_file': 'merges.txt',
}

# Download locations of the pretrained vocabulary / merges files, keyed first by
# file kind (matching VOCAB_FILES_NAMES) and then by model shortcut name.
PRETRAINED_VOCAB_FILES_MAP = {
    'vocab_file':
    {
        'ctrl': "https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-vocab.json",
    },
    'merges_file':
    {
        'ctrl': "https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-merges.txt",
    },
}

# Maximum number of input token positions for each pretrained model (used as
# ``max_model_input_sizes``). The previous value, 1280, appears to be CTRL's
# hidden width rather than its sequence length; 256 matches the sequence length
# of the released ``seqlen256`` CTRL checkpoint -- confirm against the model config.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    'ctrl': 256,
}
|
|
|
|
@lru_cache()
def bytes_to_unicode():
    """
    Return a dict mapping each of the 256 byte values to a distinct unicode character.

    Byte values that are already printable, non-space characters ('!'..'~',
    '¡'..'¬', '®'..'ÿ') map to themselves. The remaining 68 byte values (ASCII
    control characters and space, DEL, 0x80-0xA0, and the soft hyphen 0xAD) are
    shifted to consecutive code points starting at 256. This way no byte is
    represented by a whitespace/control character that the BPE code barfs on.

    Reversible BPE codes work on unicode strings, which would otherwise need a
    large number of unicode characters in the vocab to avoid UNKs (around 5K for
    decent coverage of a 10B-token dataset, a significant share of a typical
    32K BPE vocab). Lookup tables between bytes and characters avoid that.
    """
    to_char = unichr if sys.version_info[0] == 2 else chr
    # Printable byte ranges kept as-is: '!'..'~', '¡'..'¬', '®'..'ÿ'.
    kept = (list(range(0x21, 0x7E + 1))
            + list(range(0xA1, 0xAC + 1))
            + list(range(0xAE, 0xFF + 1)))
    kept_set = set(kept)
    # Every other byte, in ascending order, gets a code point past 255.
    shifted = [byte for byte in range(256) if byte not in kept_set]

    byte_values = kept + shifted
    code_points = kept + [256 + offset for offset in range(len(shifted))]
    return dict(zip(byte_values, [to_char(cp) for cp in code_points]))
|
|
|
|
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    Word is represented as a tuple of symbols (symbols being variable-length
    strings). Words with fewer than two symbols -- including the empty word --
    have no pairs, so an empty set is returned for them.
    """
    # Pair each symbol with its successor: (w0, w1), (w1, w2), ...
    return set(zip(word, word[1:]))
|
|
|
|
class CTRLTokenizer(PreTrainedTokenizer):
    """
    CTRL BPE tokenizer. Peculiarities:

    - The input is split on whitespace, then every word is split into sub-word
      units with Byte-Pair-Encoding, using the ranked merges read from ``merges_file``.
    - Every sub-word unit except the last one of a word carries a trailing ``@@``
      continuation marker: a word split into ``Hel`` and ``lo`` becomes the tokens
      ``"Hel@@", "lo"``. Decoding joins tokens with spaces and removes the ``"@@ "``
      markers to glue the pieces of each word back together.
    - Because the input is split on whitespace, the exact spacing of the original
      text (leading/trailing/repeated spaces, newlines) is not preserved by an
      encode/decode round trip.
    """
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(self, vocab_file, merges_file, errors='replace', unk_token="<unk>",
                 bos_token="<|endoftext|>", eos_token="<|endoftext|>", **kwargs):
        """Load the vocabulary (JSON token -> id) and the ranked BPE merges.

        Args:
            vocab_file: path of the JSON vocabulary (token string -> id).
            merges_file: path of the merges file; the first line is a header
                (``save_vocabulary`` writes ``#version: 0.2`` there), each following
                non-blank line holds one merge as whitespace-separated symbols.
            errors: kept for backward compatibility and stored as ``self.errors``;
                decoding in this tokenizer does not go through bytes.
        """
        super(CTRLTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
        self.max_len_single_sentence = self.max_len  # no default special tokens - you can update this value if you add special tokens
        self.max_len_sentences_pair = self.max_len  # no default special tokens - you can update this value if you add special tokens

        # Read both files inside ``with`` so the handles are closed deterministically.
        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # retained for backward compatibility (see docstring)
        # Byte-level lookup tables retained as attributes for backward compatibility;
        # this tokenizer's encode/decode path does not use them.
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

        with open(merges_file, encoding='utf-8') as merges_handle:
            # Skip the header line.
            merge_lines = merges_handle.read().split('\n')[1:]
        # Filter blank lines instead of blindly dropping the last line, so a file
        # without a trailing newline keeps its final merge.
        bpe_merges = [tuple(line.split()) for line in merge_lines if line.strip()]
        # Lower rank = merge learned earlier = applied first.
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}  # word -> BPE result; grows with the number of distinct words seen

        # GPT-2-style pre-tokenization pattern. It is not used by ``_tokenize``
        # (which splits on whitespace) and is kept only for backward compatibility.
        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    @property
    def vocab_size(self):
        """Number of entries in the loaded vocabulary."""
        return len(self.encoder)

    def bpe(self, token):
        """Apply the learned BPE merges to a single whitespace-free word.

        Returns the word's sub-word units joined by ``"@@ "``: every unit except
        the last carries the ``@@`` continuation marker. A single-character word
        is returned unchanged. Results are memoized in ``self.cache``.
        """
        if token in self.cache:
            return self.cache[token]
        if not token:
            return token
        # Tag the last symbol with '</w>' so merges can tell word-final pieces apart.
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            # Single-character word: nothing to merge.
            self.cache[token] = token
            return token

        while True:
            # Apply the highest-priority (lowest-rank) merge present in the word.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # ``first`` no longer occurs: copy the rest unchanged.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)

        # Join units with the "@@ " continuation marker, then drop the '</w>' tag,
        # which always sits at the end of the last unit.
        word = '@@ '.join(word)[:-len('</w>')]
        self.cache[token] = word
        return word

    def _tokenize(self, text, add_prefix_space=False):
        """ Tokenize a string into BPE sub-word tokens.

        Args:
            - add_prefix_space (boolean, default False):
                Accepted for API compatibility only. It has no effect: the text is
                split on whitespace, so a leading space never reaches the BPE step.
        """
        bpe_tokens = []
        for word in text.split():
            # Tokens are looked up in the vocabulary as-is: decoding (``decode`` /
            # ``convert_tokens_to_string``) does not reverse any byte-level mapping,
            # so none is applied here either.
            bpe_tokens.extend(self.bpe(word).split(' '))
        return bpe_tokens

    def _convert_token_to_id(self, token):
        """ Converts a token (str/unicode) in an id using the vocab; unknown tokens map to the id of ``unk_token``. """
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (string/unicode) using the vocab.

        Ids missing from the vocabulary map to ``unk_token`` rather than ``None``,
        so joining decoded tokens cannot fail on a ``None`` entry.
        """
        return self.decoder.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of BPE tokens (string) in a single string.

        Tokens are joined with spaces and the ``"@@ "`` continuation markers are
        removed, so the sub-word pieces of each word are glued back together.
        A dangling ``@@`` at the very end (e.g. output cut mid-word) is dropped too.
        """
        if isinstance(tokens, (str, type(u''))):
            # A bare token string: wrap it so the join below does not insert
            # spaces between its characters.
            tokens = [tokens]
        text = ' '.join(tokens).replace('@@ ', '')
        return re.sub('(@@ ?$)', '', text)

    def save_vocabulary(self, save_directory):
        """Save the tokenizer vocabulary and merge files to a directory.

        Returns the ``(vocab_file, merge_file)`` paths written, or ``None`` (after
        logging an error) when ``save_directory`` is not an existing directory.
        """
        if not os.path.isdir(save_directory):
            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
            return
        vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
        merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])

        with open(vocab_file, 'w', encoding='utf-8') as f:
            f.write(json.dumps(self.encoder, ensure_ascii=False))

        index = 0
        with open(merge_file, "w", encoding="utf-8") as writer:
            writer.write(u'#version: 0.2\n')
            # Merges are written in rank order, one per line.
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
                                   " Please check that the tokenizer is not corrupted!".format(merge_file))
                    index = token_index
                writer.write(' '.join(bpe_tokens) + u'\n')
                index += 1

        return vocab_file, merge_file

    def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
        """ Converts a sequence of ids into a string, undoing the BPE ``@@`` markers.

        ``clean_up_tokenization_spaces`` is accepted for API compatibility but is
        not applied by this implementation.
        """
        tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
        return self.convert_tokens_to_string(tokens)
|
|
|
|
|