add serialization semantics to tokenizers - fix transfo-xl tokenizer
This commit is contained in:
@@ -63,7 +63,10 @@ class TransfoXLTokenizer(object):
|
||||
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
|
||||
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
else:
|
||||
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
|
||||
else:
|
||||
vocab_file = pretrained_model_name_or_path
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
|
||||
@@ -141,6 +144,11 @@ class TransfoXLTokenizer(object):
|
||||
else:
|
||||
raise ValueError('No <unkown> token in vocabulary')
|
||||
|
||||
def save_vocabulary(self, vocab_path):
|
||||
index = 0
|
||||
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
|
||||
torch.save(self.__dict__, vocab_file)
|
||||
|
||||
def build_vocab(self):
|
||||
if self.vocab_file:
|
||||
print('building vocab from {}'.format(self.vocab_file))
|
||||
@@ -245,82 +253,24 @@ class TransfoXLTokenizer(object):
|
||||
def __len__(self):
|
||||
return len(self.idx2sym)
|
||||
|
||||
def _run_split_on_punc(self, text):
|
||||
"""Splits punctuation on a piece of text."""
|
||||
if text in self.never_split:
|
||||
return [text]
|
||||
chars = list(text)
|
||||
i = 0
|
||||
start_new_word = True
|
||||
output = []
|
||||
while i < len(chars):
|
||||
char = chars[i]
|
||||
if _is_punctuation(char):
|
||||
output.append([char])
|
||||
start_new_word = True
|
||||
else:
|
||||
if start_new_word:
|
||||
output.append([])
|
||||
start_new_word = False
|
||||
output[-1].append(char)
|
||||
i += 1
|
||||
|
||||
return ["".join(x) for x in output]
|
||||
|
||||
def _run_strip_accents(self, text):
|
||||
"""Strips accents from a piece of text."""
|
||||
text = unicodedata.normalize("NFD", text)
|
||||
output = []
|
||||
for char in text:
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Mn":
|
||||
continue
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def _clean_text(self, text):
|
||||
"""Performs invalid character removal and whitespace cleanup on text."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if cp == 0 or cp == 0xfffd or _is_control(char):
|
||||
continue
|
||||
if _is_whitespace(char):
|
||||
output.append(" ")
|
||||
else:
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def whitespace_tokenize(self, text):
|
||||
"""Runs basic whitespace cleaning and splitting on a piece of text."""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
if self.delimiter == '':
|
||||
tokens = text
|
||||
else:
|
||||
tokens = text.split(self.delimiter)
|
||||
return tokens
|
||||
|
||||
def tokenize(self, line, add_eos=False, add_double_eos=False):
|
||||
line = self._clean_text(line)
|
||||
line = line.strip()
|
||||
# convert to lower case
|
||||
if self.lower_case:
|
||||
line = line.lower()
|
||||
|
||||
symbols = self.whitespace_tokenize(line)
|
||||
|
||||
split_symbols = []
|
||||
for symbol in symbols:
|
||||
if self.lower_case and symbol not in self.never_split:
|
||||
symbol = symbol.lower()
|
||||
symbol = self._run_strip_accents(symbol)
|
||||
split_symbols.extend(self._run_split_on_punc(symbol))
|
||||
# empty delimiter '' will evaluate False
|
||||
if self.delimiter == '':
|
||||
symbols = line
|
||||
else:
|
||||
symbols = line.split(self.delimiter)
|
||||
|
||||
if add_double_eos: # lm1b
|
||||
return ['<S>'] + split_symbols + ['<S>']
|
||||
return ['<S>'] + symbols + ['<S>']
|
||||
elif add_eos:
|
||||
return split_symbols + ['<eos>']
|
||||
return symbols + ['<eos>']
|
||||
else:
|
||||
return split_symbols
|
||||
return symbols
|
||||
|
||||
|
||||
class LMOrderedIterator(object):
|
||||
@@ -631,42 +581,3 @@ def get_lm_corpus(datadir, dataset):
|
||||
torch.save(corpus, fn)
|
||||
|
||||
return corpus
|
||||
|
||||
def _is_whitespace(char):
|
||||
"""Checks whether `chars` is a whitespace character."""
|
||||
# \t, \n, and \r are technically contorl characters but we treat them
|
||||
# as whitespace since they are generally considered as such.
|
||||
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Zs":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_control(char):
|
||||
"""Checks whether `chars` is a control character."""
|
||||
# These are technically control characters but we count them as whitespace
|
||||
# characters.
|
||||
if char == "\t" or char == "\n" or char == "\r":
|
||||
return False
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("C"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_punctuation(char):
|
||||
"""Checks whether `chars` is a punctuation character."""
|
||||
cp = ord(char)
|
||||
# We treat all non-letter/number ASCII as punctuation.
|
||||
# Characters such as "^", "$", and "`" are not in the Unicode
|
||||
# Punctuation class but we treat them as punctuation anyways, for
|
||||
# consistency.
|
||||
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
|
||||
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("P"):
|
||||
return True
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user