Tokenization behaves the same as the original XLM preprocessing for most languages except zh, ja and th; change the API to allow specifying the language in tokenize
This commit is contained in:
@@ -20,8 +20,11 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import unicodedata
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
|
import sacremoses as sm
|
||||||
|
|
||||||
from .tokenization_utils import PreTrainedTokenizer
|
from .tokenization_utils import PreTrainedTokenizer
|
||||||
from .tokenization_bert import BasicTokenizer
|
from .tokenization_bert import BasicTokenizer
|
||||||
|
|
||||||
@@ -95,6 +98,93 @@ def text_standardize(text):
|
|||||||
text = re.sub(r'[^\S\n]+', ' ', text)
|
text = re.sub(r'[^\S\n]+', ' ', text)
|
||||||
return text.strip()
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def lowercase_and_remove_accent(text):
    """
    Lowercase and strips accents from a piece of text based on
    https://github.com/facebookresearch/XLM/blob/master/tools/lowercase_and_remove_accent.py

    The text is lowercased, decomposed with Unicode NFD normalization so that
    accents become separate combining characters, and every character whose
    Unicode category is "Mn" (nonspacing mark) is dropped.

    Args:
        text: the string to normalize.

    Returns:
        The lowercased string without nonspacing marks. The result stays in
        decomposed (NFD) form; it is not recomposed.
    """
    decomposed = unicodedata.normalize("NFD", text.lower())
    without_marks = "".join(
        char for char in decomposed if unicodedata.category(char) != "Mn"
    )
    # The final lower() is kept from the original implementation.
    return without_marks.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def replace_unicode_punct(text):
    '''
    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl

    Replaces CJK / fullwidth punctuation and fullwidth digits with ASCII
    counterparts. The ideographic and fullwidth full stops additionally
    absorb any whitespace that follows them and become ". ".

    Args:
        text: the string to normalize.

    Returns:
        The string with the listed characters replaced; every other
        character (including plain ASCII punctuation and digits) is left
        untouched.
    '''
    # Every non-ASCII source character is written as a \u escape so the
    # table cannot be silently damaged by an encoding or width-folding step
    # (an ASCII '.' or '1' on the left-hand side would corrupt ordinary text).
    single_char_replacements = (
        ('\uff0c', ','),    # fullwidth comma
        ('\u3001', ','),    # ideographic comma
        ('\u201d', '"'),    # right double quotation mark
        ('\u201c', '"'),    # left double quotation mark
        ('\u2236', ':'),    # ratio
        ('\uff1a', ':'),    # fullwidth colon
        ('\uff1f', '?'),    # fullwidth question mark
        ('\u300a', '"'),    # left double angle bracket
        ('\u300b', '"'),    # right double angle bracket
        ('\uff09', ')'),    # fullwidth right parenthesis
        ('\uff01', '!'),    # fullwidth exclamation mark
        ('\uff08', '('),    # fullwidth left parenthesis
        ('\uff1b', ';'),    # fullwidth semicolon
        # NOTE(review): fullwidth digit one maps to a double quote here, as
        # in the code being ported -- presumably mirrors the upstream Perl
        # script; confirm before "fixing".
        ('\uff11', '"'),
        ('\u300d', '"'),    # right corner bracket
        ('\u300c', '"'),    # left corner bracket
        ('\uff10', '0'),    # fullwidth digits
        ('\uff13', '3'),
        ('\uff12', '2'),
        ('\uff15', '5'),
        ('\uff16', '6'),
        ('\uff19', '9'),
        ('\uff17', '7'),
        ('\uff18', '8'),
        ('\uff14', '4'),
        ('\uff5e', '~'),    # fullwidth tilde
        ('\u2019', '\''),   # right single quotation mark
        ('\u2026', '...'),  # horizontal ellipsis
        ('\u2501', '-'),    # box drawings heavy horizontal
        ('\u3008', '<'),    # left angle bracket
        ('\u3009', '>'),    # right angle bracket
        ('\u3010', '['),    # left black lenticular bracket
        ('\u3011', ']'),    # right black lenticular bracket
        ('\uff05', '%'),    # fullwidth percent sign
    )
    for source, target in single_char_replacements:
        text = text.replace(source, target)

    # These two need a regex (not str.replace) because they also swallow the
    # whitespace following the full stop.
    text = re.sub('\u3002\\s*', '. ', text)  # ideographic full stop
    text = re.sub('\uff0e\\s*', '. ', text)  # fullwidth full stop
    return text
|
||||||
|
|
||||||
|
|
||||||
|
def remove_non_printing_char(text):
    '''
    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/remove-non-printing-char.perl

    Drops every character whose Unicode general category starts with "C"
    (the "Other" group: control, format, surrogate, private-use and
    unassigned code points). Note that this includes tab and newline.

    Args:
        text: the string to clean.

    Returns:
        The string with those characters deleted (not replaced by a space).
    '''
    return "".join(
        char for char in text if not unicodedata.category(char).startswith('C')
    )
|
||||||
|
|
||||||
|
|
||||||
|
def romanian_preprocessing(text):
    '''Sennrich's WMT16 scripts for Romanian preprocessing, used by model `xlm-mlm-enro-1024`

    Applies two passes in order: the cedilla forms of s/t are first rewritten
    to their comma-below forms, then all Romanian diacritic letters are
    replaced by plain ASCII letters.

    Args:
        text: the string to normalize.

    Returns:
        The string with the listed Romanian letters replaced.
    '''
    replacements = (
        # https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/normalise-romanian.py
        ("\u015e", "\u0218"), ("\u015f", "\u0219"),  # s-cedilla -> s-comma
        ("\u0162", "\u021a"), ("\u0163", "\u021b"),  # t-cedilla -> t-comma
        # https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/remove-diacritics.py
        ("\u0218", "S"), ("\u0219", "s"),  # s-comma
        ("\u021a", "T"), ("\u021b", "t"),  # t-comma
        ("\u0102", "A"), ("\u0103", "a"),  # a-breve
        ("\u00C2", "A"), ("\u00E2", "a"),  # a-circumflex
        ("\u00CE", "I"), ("\u00EE", "i"),  # i-circumflex
    )
    for source, target in replacements:
        text = text.replace(source, target)
    return text
|
||||||
|
|
||||||
|
|
||||||
class XLMTokenizer(PreTrainedTokenizer):
|
class XLMTokenizer(PreTrainedTokenizer):
|
||||||
"""
|
"""
|
||||||
BPE tokenizer for XLM, adapted from OpenAI BPE tokenizer. Peculiarities:
|
BPE tokenizer for XLM, adapted from OpenAI BPE tokenizer. Peculiarities:
|
||||||
@@ -122,16 +212,14 @@ class XLMTokenizer(PreTrainedTokenizer):
|
|||||||
cls_token=cls_token, mask_token=mask_token,
|
cls_token=cls_token, mask_token=mask_token,
|
||||||
additional_special_tokens=additional_special_tokens,
|
additional_special_tokens=additional_special_tokens,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
try:
|
|
||||||
import ftfy
|
# cache of sm.MosesPunctNormalizer instance
|
||||||
from spacy.lang.en import English
|
self.cache_moses_punct_normalizer = dict()
|
||||||
_nlp = English()
|
# cache of sm.MosesTokenizer instance
|
||||||
self.nlp = _nlp.Defaults.create_tokenizer(_nlp)
|
self.cache_moses_tokenizer = dict()
|
||||||
self.fix_text = ftfy.fix_text
|
self.lang_with_custom_tokenizer = set(['zh', 'th', 'ja'])
|
||||||
except ImportError:
|
# True for current supported model (v1.2.0), False for XLM-17 & 100
|
||||||
logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
|
self.do_lowercase_and_remove_accent = True
|
||||||
self.nlp = BasicTokenizer(do_lower_case=True)
|
|
||||||
self.fix_text = None
|
|
||||||
|
|
||||||
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
|
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
|
||||||
self.decoder = {v:k for k,v in self.encoder.items()}
|
self.decoder = {v:k for k,v in self.encoder.items()}
|
||||||
@@ -140,6 +228,28 @@ class XLMTokenizer(PreTrainedTokenizer):
|
|||||||
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||||
self.cache = {}
|
self.cache = {}
|
||||||
|
|
||||||
|
def moses_punct_norm(self, text, lang):
    """
    Normalize punctuation in `text` with a sacremoses MosesPunctNormalizer.

    One normalizer is created per language on first use and kept in
    `self.cache_moses_punct_normalizer`, keyed by `lang`.

    Args:
        text: the string to normalize.
        lang: language code passed to `sm.MosesPunctNormalizer(lang=...)`.

    Returns:
        Whatever the normalizer's `normalize(text)` returns.
    """
    try:
        punct_normalizer = self.cache_moses_punct_normalizer[lang]
    except KeyError:
        # First request for this language: build the normalizer and cache it.
        punct_normalizer = sm.MosesPunctNormalizer(lang=lang)
        self.cache_moses_punct_normalizer[lang] = punct_normalizer
    return punct_normalizer.normalize(text)
|
||||||
|
|
||||||
|
def moses_tokenize(self, text, lang):
    """
    Tokenize `text` with a sacremoses MosesTokenizer.

    One tokenizer is created per language on first use and kept in
    `self.cache_moses_tokenizer`, keyed by `lang`.

    Args:
        text: the string to tokenize.
        lang: language code passed to `sm.MosesTokenizer(lang=...)`.

    Returns:
        The result of `tokenize(text, return_str=False, escape=False)`;
        the caller iterates over it token by token.
    """
    try:
        moses_tokenizer = self.cache_moses_tokenizer[lang]
    except KeyError:
        # First request for this language: build the tokenizer and cache it.
        moses_tokenizer = sm.MosesTokenizer(lang=lang)
        self.cache_moses_tokenizer[lang] = moses_tokenizer
    return moses_tokenizer.tokenize(text, return_str=False, escape=False)
|
||||||
|
|
||||||
|
def moses_pipeline(self, text, lang):
    """
    Run the Moses-style text clean-up steps on `text`, in this order:
    `replace_unicode_punct`, `self.moses_punct_norm` (for `lang`), then
    `remove_non_printing_char`.

    Args:
        text: the string to clean.
        lang: language code forwarded to `moses_punct_norm`.

    Returns:
        The cleaned string (not yet tokenized).
    """
    text = replace_unicode_punct(text)
    text = self.moses_punct_norm(text, lang)
    return remove_non_printing_char(text)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def vocab_size(self):
|
def vocab_size(self):
|
||||||
return len(self.encoder)
|
return len(self.encoder)
|
||||||
@@ -187,19 +297,21 @@ class XLMTokenizer(PreTrainedTokenizer):
|
|||||||
self.cache[token] = word
|
self.cache[token] = word
|
||||||
return word
|
return word
|
||||||
|
|
||||||
def _tokenize(self, text):
|
def _tokenize(self, text, lang='en'):
|
||||||
""" Tokenize a string. """
|
""" Tokenize a string. """
|
||||||
split_tokens = []
|
split_tokens = []
|
||||||
if self.fix_text is None:
|
if self.do_lowercase_and_remove_accent:
|
||||||
# Using BERT's BasicTokenizer
|
text = lowercase_and_remove_accent(text)
|
||||||
text = self.nlp.tokenize(text)
|
if lang not in self.lang_with_custom_tokenizer:
|
||||||
|
text = self.moses_pipeline(text, lang=lang)
|
||||||
|
# TODO: make sure we are using `xlm-mlm-enro-1024`, since XLM-100 doesn't have this step
|
||||||
|
if lang == 'ro':
|
||||||
|
text = romanian_preprocessing(text)
|
||||||
|
text = self.moses_tokenize(text, lang=lang)
|
||||||
for token in text:
|
for token in text:
|
||||||
split_tokens.extend([t for t in self.bpe(token).split(' ')])
|
split_tokens.extend([t for t in self.bpe(token).split(' ')])
|
||||||
else:
|
else:
|
||||||
# Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
|
raise ValueError
|
||||||
text = self.nlp(text_standardize(self.fix_text(text)))
|
|
||||||
for token in text:
|
|
||||||
split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
|
|
||||||
return split_tokens
|
return split_tokens
|
||||||
|
|
||||||
def _convert_token_to_id(self, token):
|
def _convert_token_to_id(self, token):
|
||||||
|
|||||||
@@ -10,3 +10,5 @@ requests
|
|||||||
regex
|
regex
|
||||||
# For XLNet
|
# For XLNet
|
||||||
sentencepiece
|
sentencepiece
|
||||||
|
# For XLM
|
||||||
|
sacremoses
|
||||||
3
setup.py
3
setup.py
@@ -55,7 +55,8 @@ setup(
|
|||||||
'requests',
|
'requests',
|
||||||
'tqdm',
|
'tqdm',
|
||||||
'regex',
|
'regex',
|
||||||
'sentencepiece'],
|
'sentencepiece',
|
||||||
|
'sacremoses'],
|
||||||
entry_points={
|
entry_points={
|
||||||
'console_scripts': [
|
'console_scripts': [
|
||||||
"pytorch_transformers=pytorch_transformers.__main__:main",
|
"pytorch_transformers=pytorch_transformers.__main__:main",
|
||||||
|
|||||||
Reference in New Issue
Block a user