Move source code inside a src subdirectory.
This prevents transformers from being importable simply because the CWD
is the root of the git repository, while not being importable from other
directories. That led to inconsistent behavior, especially in examples.
Once you fetch this commit, in your dev environment, you must run:
$ pip uninstall transformers
$ pip install -e .
This commit is contained in:
249
src/transformers/tokenization_ctrl.py
Normal file
249
src/transformers/tokenization_ctrl.py
Normal file
@@ -0,0 +1,249 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 Salesforce and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tokenization classes for Salesforce CTRL."""
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from io import open
|
||||
|
||||
import regex as re
|
||||
|
||||
from .tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {
|
||||
"vocab_file": "vocab.json",
|
||||
"merges_file": "merges.txt",
|
||||
}
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
"vocab_file": {"ctrl": "https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-vocab.json"},
|
||||
"merges_file": {"ctrl": "https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-merges.txt"},
|
||||
}
|
||||
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
"ctrl": 256,
|
||||
}
|
||||
|
||||
CONTROL_CODES = {
|
||||
"Pregnancy": 168629,
|
||||
"Christianity": 7675,
|
||||
"Explain": 106423,
|
||||
"Fitness": 63440,
|
||||
"Saving": 63163,
|
||||
"Ask": 27171,
|
||||
"Ass": 95985,
|
||||
"Joke": 163509,
|
||||
"Questions": 45622,
|
||||
"Thoughts": 49605,
|
||||
"Retail": 52342,
|
||||
"Feminism": 164338,
|
||||
"Writing": 11992,
|
||||
"Atheism": 192263,
|
||||
"Netflix": 48616,
|
||||
"Computing": 39639,
|
||||
"Opinion": 43213,
|
||||
"Alone": 44967,
|
||||
"Funny": 58917,
|
||||
"Gaming": 40358,
|
||||
"Human": 4088,
|
||||
"India": 1331,
|
||||
"Joker": 77138,
|
||||
"Diet": 36206,
|
||||
"Legal": 11859,
|
||||
"Norman": 4939,
|
||||
"Tip": 72689,
|
||||
"Weight": 52343,
|
||||
"Movies": 46273,
|
||||
"Running": 23425,
|
||||
"Science": 2090,
|
||||
"Horror": 37793,
|
||||
"Confession": 60572,
|
||||
"Finance": 12250,
|
||||
"Politics": 16360,
|
||||
"Scary": 191985,
|
||||
"Support": 12654,
|
||||
"Technologies": 32516,
|
||||
"Teenage": 66160,
|
||||
"Event": 32769,
|
||||
"Learned": 67460,
|
||||
"Notion": 182770,
|
||||
"Wikipedia": 37583,
|
||||
"Books": 6665,
|
||||
"Extract": 76050,
|
||||
"Confessions": 102701,
|
||||
"Conspiracy": 75932,
|
||||
"Links": 63674,
|
||||
"Narcissus": 150425,
|
||||
"Relationship": 54766,
|
||||
"Relationships": 134796,
|
||||
"Reviews": 41671,
|
||||
"News": 4256,
|
||||
"Translation": 26820,
|
||||
"multilingual": 128406,
|
||||
}
|
||||
|
||||
|
||||
def get_pairs(word):
|
||||
"""Return set of symbol pairs in a word.
|
||||
|
||||
Word is represented as tuple of symbols (symbols being variable-length strings).
|
||||
"""
|
||||
pairs = set()
|
||||
prev_char = word[0]
|
||||
for char in word[1:]:
|
||||
pairs.add((prev_char, char))
|
||||
prev_char = char
|
||||
|
||||
pairs = set(pairs)
|
||||
return pairs
|
||||
|
||||
|
||||
class CTRLTokenizer(PreTrainedTokenizer):
|
||||
"""
|
||||
CTRL BPE tokenizer. Peculiarities:
|
||||
- Byte-Pair-Encoding
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
control_codes = CONTROL_CODES
|
||||
|
||||
def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
|
||||
super(CTRLTokenizer, self).__init__(unk_token=unk_token, **kwargs)
|
||||
self.max_len_single_sentence = (
|
||||
self.max_len
|
||||
) # no default special tokens - you can update this value if you add special tokens
|
||||
self.max_len_sentences_pair = (
|
||||
self.max_len
|
||||
) # no default special tokens - you can update this value if you add special tokens
|
||||
|
||||
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
||||
self.encoder = json.load(vocab_handle)
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
with open(merges_file, encoding="utf-8") as merges_handle:
|
||||
merges = merges_handle.read().split("\n")[1:-1]
|
||||
merges = [tuple(merge.split()) for merge in merges]
|
||||
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||
self.cache = {}
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return len(self.encoder)
|
||||
|
||||
def bpe(self, token):
|
||||
if token in self.cache:
|
||||
return self.cache[token]
|
||||
word = tuple(token)
|
||||
word = tuple(list(word[:-1]) + [word[-1] + "</w>"])
|
||||
pairs = get_pairs(word)
|
||||
|
||||
if not pairs:
|
||||
return token
|
||||
|
||||
while True:
|
||||
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
first, second = bigram
|
||||
new_word = []
|
||||
i = 0
|
||||
while i < len(word):
|
||||
try:
|
||||
j = word.index(first, i)
|
||||
except ValueError:
|
||||
new_word.extend(word[i:])
|
||||
break
|
||||
else:
|
||||
new_word.extend(word[i:j])
|
||||
i = j
|
||||
|
||||
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
|
||||
new_word.append(first + second)
|
||||
i += 2
|
||||
else:
|
||||
new_word.append(word[i])
|
||||
i += 1
|
||||
new_word = tuple(new_word)
|
||||
word = new_word
|
||||
if len(word) == 1:
|
||||
break
|
||||
else:
|
||||
pairs = get_pairs(word)
|
||||
word = "@@ ".join(word)
|
||||
word = word[:-4]
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def _tokenize(self, text):
|
||||
""" Tokenize a string.
|
||||
"""
|
||||
split_tokens = []
|
||||
|
||||
words = re.findall(r"\S+\n?", text)
|
||||
|
||||
for token in words:
|
||||
split_tokens.extend([t for t in self.bpe(token).split(" ")])
|
||||
return split_tokens
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
""" Converts a token (str/unicode) in an id using the vocab. """
|
||||
return self.encoder.get(token, self.encoder.get(self.unk_token))
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
|
||||
return self.decoder.get(index, self.unk_token)
|
||||
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
""" Converts a sequence of tokens (string) in a single string. """
|
||||
out_string = " ".join(tokens).replace("@@ ", "").strip()
|
||||
return out_string
|
||||
|
||||
def save_vocabulary(self, save_directory):
|
||||
"""Save the tokenizer vocabulary and merge files to a directory."""
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
||||
return
|
||||
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
|
||||
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES["merges_file"])
|
||||
|
||||
with open(vocab_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(self.encoder, ensure_ascii=False))
|
||||
|
||||
index = 0
|
||||
with open(merge_file, "w", encoding="utf-8") as writer:
|
||||
writer.write("#version: 0.2\n")
|
||||
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
|
||||
if index != token_index:
|
||||
logger.warning(
|
||||
"Saving vocabulary to {}: BPE merge indices are not consecutive."
|
||||
" Please check that the tokenizer is not corrupted!".format(merge_file)
|
||||
)
|
||||
index = token_index
|
||||
writer.write(" ".join(bpe_tokens) + "\n")
|
||||
index += 1
|
||||
|
||||
return vocab_file, merge_file
|
||||
|
||||
# def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
|
||||
# filtered_tokens = ' '.join(self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens))
|
||||
# tokens_generated_so_far = re.sub('(@@ )', '', string=filtered_tokens)
|
||||
# tokens_generated_so_far = re.sub('(@@ ?$)', '', string=tokens_generated_so_far)
|
||||
# return ''.join(tokens_generated_so_far)
|
||||
Reference in New Issue
Block a user