Add custom tokenizer for zh and ja
This commit is contained in:
@@ -23,7 +23,11 @@ import re
|
|||||||
import unicodedata
|
import unicodedata
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
|
import jieba
|
||||||
|
import Mykytea
|
||||||
import sacremoses as sm
|
import sacremoses as sm
|
||||||
|
from nltk.tokenize.stanford_segmenter import StanfordSegmenter
|
||||||
|
from pythainlp.tokenize import word_tokenize as th_word_tokenize
|
||||||
|
|
||||||
from .tokenization_utils import PreTrainedTokenizer
|
from .tokenization_utils import PreTrainedTokenizer
|
||||||
from .tokenization_bert import BasicTokenizer
|
from .tokenization_bert import BasicTokenizer
|
||||||
@@ -83,21 +87,6 @@ def get_pairs(word):
|
|||||||
prev_char = char
|
prev_char = char
|
||||||
return pairs
|
return pairs
|
||||||
|
|
||||||
def text_standardize(text):
|
|
||||||
"""
|
|
||||||
fixes some issues the spacy tokenizer had on books corpus
|
|
||||||
also does some whitespace standardization
|
|
||||||
"""
|
|
||||||
text = text.replace('—', '-')
|
|
||||||
text = text.replace('–', '-')
|
|
||||||
text = text.replace('―', '-')
|
|
||||||
text = text.replace('…', '...')
|
|
||||||
text = text.replace('´', "'")
|
|
||||||
text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text)
|
|
||||||
text = re.sub(r'\s*\n\s*', ' \n ', text)
|
|
||||||
text = re.sub(r'[^\S\n]+', ' ', text)
|
|
||||||
return text.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def lowercase_and_remove_accent(text):
|
def lowercase_and_remove_accent(text):
|
||||||
"""
|
"""
|
||||||
@@ -120,7 +109,7 @@ def replace_unicode_punct(text):
|
|||||||
Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl
|
Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl
|
||||||
'''
|
'''
|
||||||
text = text.replace(',', ',')
|
text = text.replace(',', ',')
|
||||||
text = text.replace('。 *', '. ')
|
text = re.sub(r'。\s*', '. ', text)
|
||||||
text = text.replace('、', ',')
|
text = text.replace('、', ',')
|
||||||
text = text.replace('”', '"')
|
text = text.replace('”', '"')
|
||||||
text = text.replace('“', '"')
|
text = text.replace('“', '"')
|
||||||
@@ -220,6 +209,8 @@ class XLMTokenizer(PreTrainedTokenizer):
|
|||||||
self.lang_with_custom_tokenizer = set(['zh', 'th', 'ja'])
|
self.lang_with_custom_tokenizer = set(['zh', 'th', 'ja'])
|
||||||
# True for current supported model (v1.2.0), False for XLM-17 & 100
|
# True for current supported model (v1.2.0), False for XLM-17 & 100
|
||||||
self.do_lowercase_and_remove_accent = True
|
self.do_lowercase_and_remove_accent = True
|
||||||
|
self.ja_word_tokenizer = None
|
||||||
|
self.zh_word_tokenizer = None
|
||||||
|
|
||||||
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
|
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
|
||||||
self.decoder = {v:k for k,v in self.encoder.items()}
|
self.decoder = {v:k for k,v in self.encoder.items()}
|
||||||
@@ -250,6 +241,33 @@ class XLMTokenizer(PreTrainedTokenizer):
|
|||||||
text = remove_non_printing_char(text)
|
text = remove_non_printing_char(text)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
def ja_tokenize(self, text):
|
||||||
|
if self.ja_word_tokenizer is None:
|
||||||
|
try:
|
||||||
|
self.ja_word_tokenizer = Mykytea.Mykytea('-model %s/local/share/kytea/model.bin' % os.path.expanduser('~'))
|
||||||
|
except RuntimeError:
|
||||||
|
logger.error("Make sure you install KyTea (https://github.com/neubig/kytea) with the following steps")
|
||||||
|
logger.error("1. git clone git@github.com:neubig/kytea.git && cd kytea")
|
||||||
|
logger.error("2. autoreconf -i")
|
||||||
|
logger.error("3. ./configure --prefix=$HOME/local")
|
||||||
|
logger.error("4. make && make install")
|
||||||
|
import sys; sys.exit()
|
||||||
|
return list(self.ja_word_tokenizer.getWS(text))
|
||||||
|
|
||||||
|
def zh_tokenize(self, text):
|
||||||
|
if self.zh_word_tokenizer is None:
|
||||||
|
try:
|
||||||
|
self.zh_word_tokenizer = StanfordSegmenter()
|
||||||
|
self.zh_word_tokenizer.default_config('zh')
|
||||||
|
except LookupError:
|
||||||
|
logger.error("Make sure you download stanford-segmenter (https://nlp.stanford.edu/software/stanford-segmenter-2018-10-16.zip) with the following steps")
|
||||||
|
logger.error("1. wget https://nlp.stanford.edu/software/stanford-segmenter-2018-10-16.zip -O /path/to/stanford-segmenter-2018-10-16.zip")
|
||||||
|
logger.error("2. cd /path/to && unzip stanford-segmenter-2018-10-16.zip")
|
||||||
|
logger.error("3. cd stanford-segmenter-2018-10-16 && cp stanford-segmenter-3.9.2.jar stanford-segmenter.jar")
|
||||||
|
logger.error("4. set env variable STANFORD_SEGMENTER=/path/to/stanford-segmenter-2018-10-16")
|
||||||
|
import sys; sys.exit()
|
||||||
|
return self.zh_word_tokenizer.segment(text)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def vocab_size(self):
|
def vocab_size(self):
|
||||||
return len(self.encoder)
|
return len(self.encoder)
|
||||||
@@ -299,7 +317,6 @@ class XLMTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
def _tokenize(self, text, lang='en'):
|
def _tokenize(self, text, lang='en'):
|
||||||
""" Tokenize a string. """
|
""" Tokenize a string. """
|
||||||
split_tokens = []
|
|
||||||
if self.do_lowercase_and_remove_accent:
|
if self.do_lowercase_and_remove_accent:
|
||||||
text = lowercase_and_remove_accent(text)
|
text = lowercase_and_remove_accent(text)
|
||||||
if lang not in self.lang_with_custom_tokenizer:
|
if lang not in self.lang_with_custom_tokenizer:
|
||||||
@@ -308,10 +325,24 @@ class XLMTokenizer(PreTrainedTokenizer):
|
|||||||
if lang == 'ro':
|
if lang == 'ro':
|
||||||
text = romanian_preprocessing(text)
|
text = romanian_preprocessing(text)
|
||||||
text = self.moses_tokenize(text, lang=lang)
|
text = self.moses_tokenize(text, lang=lang)
|
||||||
|
elif lang == 'th':
|
||||||
|
text = self.moses_pipeline(text, lang=lang)
|
||||||
|
text = th_word_tokenize(text)
|
||||||
|
elif lang == 'zh':
|
||||||
|
# text = self.zh_tokenize(text)
|
||||||
|
text = ' '.join(jieba.cut(text))
|
||||||
|
text = self.moses_pipeline(text, lang=lang)
|
||||||
|
text = text.split()
|
||||||
|
elif lang == 'ja':
|
||||||
|
text = self.moses_pipeline(text, lang=lang)
|
||||||
|
text = self.ja_tokenize(text)
|
||||||
|
else:
|
||||||
|
raise ValueError('It should not reach here')
|
||||||
|
|
||||||
|
split_tokens = []
|
||||||
for token in text:
|
for token in text:
|
||||||
split_tokens.extend([t for t in self.bpe(token).split(' ')])
|
split_tokens.extend([t for t in self.bpe(token).split(' ')])
|
||||||
else:
|
|
||||||
raise ValueError
|
|
||||||
return split_tokens
|
return split_tokens
|
||||||
|
|
||||||
def _convert_token_to_id(self, token):
|
def _convert_token_to_id(self, token):
|
||||||
|
|||||||
@@ -12,3 +12,7 @@ regex
|
|||||||
sentencepiece
|
sentencepiece
|
||||||
# For XLM
|
# For XLM
|
||||||
sacremoses
|
sacremoses
|
||||||
|
pythainlp
|
||||||
|
kytea
|
||||||
|
nltk
|
||||||
|
jieba
|
||||||
6
setup.py
6
setup.py
@@ -56,7 +56,11 @@ setup(
|
|||||||
'tqdm',
|
'tqdm',
|
||||||
'regex',
|
'regex',
|
||||||
'sentencepiece',
|
'sentencepiece',
|
||||||
'sacremoses'],
|
'sacremoses',
|
||||||
|
'pythainlp',
|
||||||
|
'kytea',
|
||||||
|
'nltk',
|
||||||
|
'jieba'],
|
||||||
entry_points={
|
entry_points={
|
||||||
'console_scripts': [
|
'console_scripts': [
|
||||||
"pytorch_transformers=pytorch_transformers.__main__:main",
|
"pytorch_transformers=pytorch_transformers.__main__:main",
|
||||||
|
|||||||
Reference in New Issue
Block a user