added tests for OpenAI GPT and Transformer-XL tokenizers
This commit is contained in:
@@ -25,6 +25,7 @@ import os
|
||||
import sys
|
||||
from collections import Counter, OrderedDict
|
||||
from io import open
|
||||
import unicodedata
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
@@ -89,8 +90,8 @@ class TransfoXLTokenizer(object):
|
||||
tokenizer.__dict__[key] = value
|
||||
return tokenizer
|
||||
|
||||
def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True,
|
||||
delimiter=None, vocab_file=None):
|
||||
def __init__(self, special=[], min_freq=0, max_size=None, lower_case=False,
|
||||
delimiter=None, vocab_file=None, never_split=("<unk>", "<eos>", "<formula>")):
|
||||
self.counter = Counter()
|
||||
self.special = special
|
||||
self.min_freq = min_freq
|
||||
@@ -98,6 +99,7 @@ class TransfoXLTokenizer(object):
|
||||
self.lower_case = lower_case
|
||||
self.delimiter = delimiter
|
||||
self.vocab_file = vocab_file
|
||||
self.never_split = never_split
|
||||
|
||||
def count_file(self, path, verbose=False, add_eos=False):
|
||||
if verbose: print('counting file {} ...'.format(path))
|
||||
@@ -132,7 +134,12 @@ class TransfoXLTokenizer(object):
|
||||
for line in f:
|
||||
symb = line.strip().split()[0]
|
||||
self.add_symbol(symb)
|
||||
self.unk_idx = self.sym2idx['<UNK>']
|
||||
if '<UNK>' in self.sym2idx:
|
||||
self.unk_idx = self.sym2idx['<UNK>']
|
||||
elif '<unk>' in self.sym2idx:
|
||||
self.unk_idx = self.sym2idx['<unk>']
|
||||
else:
|
||||
raise ValueError('No <unkown> token in vocabulary')
|
||||
|
||||
def build_vocab(self):
|
||||
if self.vocab_file:
|
||||
@@ -198,7 +205,7 @@ class TransfoXLTokenizer(object):
|
||||
self.sym2idx[sym] = len(self.idx2sym) - 1
|
||||
|
||||
def get_sym(self, idx):
|
||||
assert 0 <= idx < len(self), 'Index {} out of range'.format(idx)
|
||||
assert 0 <= idx < len(self), 'Index {} out of vocabulary range'.format(idx)
|
||||
return self.idx2sym[idx]
|
||||
|
||||
def get_idx(self, sym):
|
||||
@@ -206,9 +213,16 @@ class TransfoXLTokenizer(object):
|
||||
return self.sym2idx[sym]
|
||||
else:
|
||||
# print('encounter unk {}'.format(sym))
|
||||
assert '<eos>' not in sym
|
||||
assert hasattr(self, 'unk_idx')
|
||||
return self.sym2idx.get(sym, self.unk_idx)
|
||||
# assert '<eos>' not in sym
|
||||
if hasattr(self, 'unk_idx'):
|
||||
return self.sym2idx.get(sym, self.unk_idx)
|
||||
# Backward compatibility with pre-trained models
|
||||
elif '<unk>' in self.sym2idx:
|
||||
return self.sym2idx['<unk>']
|
||||
elif '<UNK>' in self.sym2idx:
|
||||
return self.sym2idx['<UNK>']
|
||||
else:
|
||||
raise ValueError('Token not in vocabulary and no <unk> token in vocabulary for replacement')
|
||||
|
||||
def convert_ids_to_tokens(self, indices):
|
||||
"""Converts a sequence of indices in symbols using the vocab."""
|
||||
@@ -231,24 +245,82 @@ class TransfoXLTokenizer(object):
|
||||
def __len__(self):
|
||||
return len(self.idx2sym)
|
||||
|
||||
def tokenize(self, line, add_eos=False, add_double_eos=False):
|
||||
line = line.strip()
|
||||
# convert to lower case
|
||||
if self.lower_case:
|
||||
line = line.lower()
|
||||
def _run_split_on_punc(self, text):
|
||||
"""Splits punctuation on a piece of text."""
|
||||
if text in self.never_split:
|
||||
return [text]
|
||||
chars = list(text)
|
||||
i = 0
|
||||
start_new_word = True
|
||||
output = []
|
||||
while i < len(chars):
|
||||
char = chars[i]
|
||||
if _is_punctuation(char):
|
||||
output.append([char])
|
||||
start_new_word = True
|
||||
else:
|
||||
if start_new_word:
|
||||
output.append([])
|
||||
start_new_word = False
|
||||
output[-1].append(char)
|
||||
i += 1
|
||||
|
||||
# empty delimiter '' will evaluate False
|
||||
return ["".join(x) for x in output]
|
||||
|
||||
def _run_strip_accents(self, text):
|
||||
"""Strips accents from a piece of text."""
|
||||
text = unicodedata.normalize("NFD", text)
|
||||
output = []
|
||||
for char in text:
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Mn":
|
||||
continue
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def _clean_text(self, text):
|
||||
"""Performs invalid character removal and whitespace cleanup on text."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if cp == 0 or cp == 0xfffd or _is_control(char):
|
||||
continue
|
||||
if _is_whitespace(char):
|
||||
output.append(" ")
|
||||
else:
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def whitespace_tokenize(self, text):
|
||||
"""Runs basic whitespace cleaning and splitting on a peice of text."""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
if self.delimiter == '':
|
||||
symbols = line
|
||||
tokens = text
|
||||
else:
|
||||
symbols = line.split(self.delimiter)
|
||||
tokens = text.split(self.delimiter)
|
||||
return tokens
|
||||
|
||||
def tokenize(self, line, add_eos=False, add_double_eos=False):
|
||||
line = self._clean_text(line)
|
||||
line = line.strip()
|
||||
|
||||
symbols = self.whitespace_tokenize(line)
|
||||
|
||||
split_symbols = []
|
||||
for symbol in symbols:
|
||||
if self.lower_case and symbol not in self.never_split:
|
||||
symbol = symbol.lower()
|
||||
symbol = self._run_strip_accents(symbol)
|
||||
split_symbols.extend(self._run_split_on_punc(symbol))
|
||||
|
||||
if add_double_eos: # lm1b
|
||||
return ['<S>'] + symbols + ['<S>']
|
||||
return ['<S>'] + split_symbols + ['<S>']
|
||||
elif add_eos:
|
||||
return symbols + ['<eos>']
|
||||
return split_symbols + ['<eos>']
|
||||
else:
|
||||
return symbols
|
||||
return split_symbols
|
||||
|
||||
|
||||
class LMOrderedIterator(object):
|
||||
@@ -556,3 +628,42 @@ def get_lm_corpus(datadir, dataset):
|
||||
torch.save(corpus, fn)
|
||||
|
||||
return corpus
|
||||
|
||||
def _is_whitespace(char):
|
||||
"""Checks whether `chars` is a whitespace character."""
|
||||
# \t, \n, and \r are technically contorl characters but we treat them
|
||||
# as whitespace since they are generally considered as such.
|
||||
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Zs":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_control(char):
|
||||
"""Checks whether `chars` is a control character."""
|
||||
# These are technically control characters but we count them as whitespace
|
||||
# characters.
|
||||
if char == "\t" or char == "\n" or char == "\r":
|
||||
return False
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("C"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_punctuation(char):
|
||||
"""Checks whether `chars` is a punctuation character."""
|
||||
cp = ord(char)
|
||||
# We treat all non-letter/number ASCII as punctuation.
|
||||
# Characters such as "^", "$", and "`" are not in the Unicode
|
||||
# Punctuation class but we treat them as punctuation anyways, for
|
||||
# consistency.
|
||||
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
|
||||
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("P"):
|
||||
return True
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user