Transformer-XL: Improved tokenization with sacremoses (#6322)
* Improved tokenization with sacremoses * The TransfoXLTokenizer is now using sacremoses for tokenization * Added tokenization of comma-separated and floating point numbers. * Removed prepare_for_tokenization() from tokenization_transfo_xl.py because punctuation is handled by sacremoses * Added corresponding tests * Removed test comparing TransfoXLTokenizer and TransfoXLTokenizerFast * Added deprecation warning to TransfoXLTokenizerFast * isort change Co-authored-by: Teven <teven.lescao@gmail.com> Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -22,11 +22,13 @@ import glob
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import warnings
|
||||
from collections import Counter, OrderedDict
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
import sacremoses as sm
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.implementations import BaseTokenizer
|
||||
from tokenizers.models import WordLevel
|
||||
@@ -70,6 +72,47 @@ PRETRAINED_CORPUS_ARCHIVE_MAP = {
|
||||
}
|
||||
CORPUS_NAME = "corpus.bin"
|
||||
|
||||
# (pattern, replacement) pair: a ',' or '.' directly between two digits is
# wrapped as ' @,@ ' / ' @.@ ' (the matched character is re-inserted via \g<0>).
MATCH_NUMBERS = r"(?<=\d)[,.](?=\d)", r" @\g<0>@ "
# (pattern, replacement) pairs that undo the wrapping applied by MATCH_NUMBERS.
DETOKENIZE_NUMBERS = [(r" @\,@ ", r","), (r" @\.@ ", r".")]


def tokenize_numbers(text_array: List[str]) -> List[str]:
    """
    Splits large comma-separated numbers and floating point values.

    Every ',' or '.' that sits directly between two digits is replaced by ' @,@ ' or ' @.@ ' and the
    result is split on whitespace, so a single input token can expand into several output tokens.
    Tokens without such a separator are passed through (after whitespace splitting).

    Args:
        text_array: An already tokenized text as list

    Returns:
        A list of strings with tokenized numbers

    Example::
        >>> tokenize_numbers(["$", "5,000", "1.73", "m"])
        ["$", "5", "@,@", "000", "1", "@.@", "73", "m"]
    """
    # The pattern and its replacement are identical for every token, so unpack them once.
    pattern, replacement = MATCH_NUMBERS
    tokenized = []
    for token in text_array:
        tokenized.extend(re.sub(pattern, replacement, token).split())
    return tokenized
|
||||
|
||||
|
||||
def detokenize_numbers(text: str) -> str:
    """
    Inverts the operation of `tokenize_numbers`.

    Each (pattern, replacement) pair in `DETOKENIZE_NUMBERS` is applied in order, turning
    ' @,@ ' back into ',' and ' @.@ ' back into '.'.

    Args:
        text: A string where the number should be detokenized

    Returns:
        A detokenized string

    Example::
        >>> detokenize_numbers("$ 5 @,@ 000 1 @.@ 73 m")
        "$ 5,000 1.73 m"
    """
    restored = text
    for pattern, replacement in DETOKENIZE_NUMBERS:
        restored = re.sub(pattern, replacement, restored)
    return restored
|
||||
|
||||
|
||||
class TransfoXLTokenizer(PreTrainedTokenizer):
|
||||
"""
|
||||
@@ -97,6 +140,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
||||
unk_token="<unk>",
|
||||
eos_token="<eos>",
|
||||
additional_special_tokens=["<formula>"],
|
||||
language="en",
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
@@ -118,6 +162,10 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
||||
self.punctuation_symbols = '!"#$%&()*+,-./\\:;<=>?@[\\]^_`{|}~'
|
||||
self.punction_without_space_before_pattern = re.compile(r"[^\s][{}]".format(self.punctuation_symbols))
|
||||
self.punctuation_with_space_around_pattern = self._compile_space_around_punctuation_pattern()
|
||||
self.language = language
|
||||
self.moses_punct_normalizer = sm.MosesPunctNormalizer(language)
|
||||
self.moses_tokenizer = sm.MosesTokenizer(language)
|
||||
self.moses_detokenizer = sm.MosesDetokenizer(language)
|
||||
|
||||
try:
|
||||
if pretrained_vocab_file is not None:
|
||||
@@ -300,6 +348,34 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
||||
del self.added_tokens_decoder[old_index]
|
||||
del self.added_tokens_encoder[token]
|
||||
|
||||
def moses_punct_norm(self, text):
    """Return `text` after passing it through `self.moses_punct_normalizer.normalize`."""
    normalized = self.moses_punct_normalizer.normalize(text)
    return normalized
|
||||
|
||||
def moses_tokenize(self, text):
    """
    Tokenize `text` with `self.moses_tokenizer`.

    The tokenizer is called with aggressive dash splitting enabled, escaping disabled, a list
    (not a string) requested as result, and `self.never_split` passed as the protected patterns.
    """
    tokenizer_options = {
        "aggressive_dash_splits": True,
        "return_str": False,
        "escape": False,
        "protected_patterns": self.never_split,
    }
    return self.moses_tokenizer.tokenize(text, **tokenizer_options)
|
||||
|
||||
def moses_pipeline(self, text: str) -> List[str]:
    """
    Runs the three tokenization stages in order: `self.moses_punct_norm`, `self.moses_tokenize`
    and the module-level `tokenize_numbers`.

    The first two stages delegate to :class:`sacremoses.MosesPunctNormalizer` and
    :class:`sacremoses.MosesTokenizer` (the latter with `aggressive_dash_splits=True`, see
    :func:`sacremoses.tokenize.MosesTokenizer.tokenize`). The last stage splits large
    comma-separated numbers and floating point values.
    E.g. "23,000 people are 1.80m tall" -> "23 @,@ 000 people are 1 @.@ 80m tall".

    Args:
        text: Text to be tokenized

    Returns:
        A list of tokenized strings

    Example::
        >>> tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
        >>> tokenizer.moses_pipeline("23,000 people are 1.80 m tall")
        ['23', '@,@', '000', 'people', 'are', '1', '@.@', '80', 'm', 'tall']
    """
    normalized = self.moses_punct_norm(text)
    moses_tokens = self.moses_tokenize(normalized)
    return tokenize_numbers(moses_tokens)
|
||||
|
||||
def _convert_id_to_token(self, idx):
|
||||
"""Converts an id in a token (BPE) using the vocab."""
|
||||
assert 0 <= idx < len(self), "Index {} out of vocabulary range".format(idx)
|
||||
@@ -323,9 +399,12 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
||||
raise ValueError("Token not in vocabulary and no <unk> token in vocabulary for replacement")
|
||||
|
||||
def convert_tokens_to_string(self, tokens):
    """
    Converts a sequence of tokens (string) in a single string.

    The tokens are joined with `self.moses_detokenizer.detokenize`; afterwards the split numbers
    are converted back into their original form by `detokenize_numbers` and surrounding
    whitespace is stripped.
    """
    # The earlier `" ".join(tokens)` implementation returned before this code could run,
    # so only the sacremoses-based path is kept.
    out_string = self.moses_detokenizer.detokenize(tokens)
    return detokenize_numbers(out_string).strip()
|
||||
|
||||
def convert_to_tensor(self, symbols):
|
||||
return torch.LongTensor(self.convert_tokens_to_ids(symbols))
|
||||
@@ -347,7 +426,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
||||
if self.delimiter == "":
|
||||
symbols = line
|
||||
else:
|
||||
symbols = line.split(self.delimiter)
|
||||
symbols = self.moses_pipeline(line)
|
||||
|
||||
if add_double_eos: # lm1b
|
||||
return ["<S>"] + symbols + ["<S>"]
|
||||
@@ -356,19 +435,6 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
||||
else:
|
||||
return symbols
|
||||
|
||||
def prepare_for_tokenization(self, text, is_pretokenized=False, **kwargs):
    """
    Optionally rewrite punctuation in `text` before tokenization.

    Pops `add_space_before_punct_symbol` (default False) from `kwargs`. When it is truthy,
    `self.punctuation_with_space_around_pattern` is substituted in `text`. Otherwise, a warning
    is logged if `self.punction_without_space_before_pattern` matches anywhere in `text`.
    `is_pretokenized` is accepted but not used here.

    Returns:
        A `(text, kwargs)` tuple with the possibly rewritten text and the remaining kwargs.
    """
    # add spaces before punctuation symbols as should be done in transfo-xl
    if kwargs.pop("add_space_before_punct_symbol", False):
        spaced_text = self.punctuation_with_space_around_pattern.sub(r" ", text)
        return (spaced_text, kwargs)

    # searches until the first occurrence of a punctuation symbol without surrounding spaces
    if self.punction_without_space_before_pattern.search(text):
        logger.warning(
            "You might want to consider setting `add_space_before_punct_symbol=True` as an argument to the `tokenizer.encode()` to avoid tokenizing words with punctuation symbols to the `<unk>` token"
        )

    return (text, kwargs)
|
||||
|
||||
|
||||
class _TransfoXLDelimiterLookupTokenizer(BaseTokenizer):
|
||||
def __init__(
|
||||
@@ -484,6 +550,11 @@ class TransfoXLTokenizerFast(PreTrainedTokenizerFast):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
warnings.warn(
|
||||
"The class `TransfoXLTokenizerFast` is deprecated and will be removed in a future version. Please use `TransfoXLTokenizer` with it's enhanced tokenization instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
logger.warning(
|
||||
"Please note you will not be able to load the vocabulary in"
|
||||
|
||||
@@ -12,14 +12,12 @@ from transformers import (
|
||||
OpenAIGPTTokenizer,
|
||||
PreTrainedTokenizer,
|
||||
RobertaTokenizer,
|
||||
TransfoXLTokenizer,
|
||||
is_torch_available,
|
||||
)
|
||||
from transformers.testing_utils import get_tests_dir, require_torch
|
||||
from transformers.tokenization_distilbert import DistilBertTokenizerFast
|
||||
from transformers.tokenization_openai import OpenAIGPTTokenizerFast
|
||||
from transformers.tokenization_roberta import RobertaTokenizerFast
|
||||
from transformers.tokenization_transfo_xl import TransfoXLTokenizerFast
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -895,17 +893,3 @@ class NoPaddingTokenFastTokenizerMatchingTest(CommonFastTokenizerTest):
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
)
|
||||
|
||||
|
||||
class TransfoXLFastTokenizerTest(NoPaddingTokenFastTokenizerMatchingTest):
    """Runs the inherited fast/slow matching tests for the TransfoXL tokenizer pair (torch required)."""

    # Single entry pairing TransfoXLTokenizerFast with TransfoXLTokenizer.
    TOKENIZERS_CLASSES = frozenset(
        (Tokenizer("TransfoXL", TransfoXLTokenizerFast, TransfoXLTokenizer, "pretrained_vocab_file", None, None),)
    )

    @require_torch
    def test_all_tokenizers(self):
        """Delegate to the parent implementation; only the torch requirement is added."""
        super().test_all_tokenizers()

    @require_torch
    def test_pretokenized_tokenizers(self):
        """Delegate to the parent implementation; only the torch requirement is added."""
        super().test_pretokenized_tokenizers()
|
||||
|
||||
@@ -83,6 +83,44 @@ class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer.tokenize(" \tHeLLo ! how \n Are yoU ? "), ["HeLLo", "!", "how", "Are", "yoU", "?"]
|
||||
)
|
||||
|
||||
def test_full_tokenizer_moses_numbers(self):
    """Check tokenization of brackets, dashes, possessives and numbers, and the round trip back to text."""
    tokenizer = TransfoXLTokenizer(lower_case=False)
    text_in = "Hello (bracket) and side-scrolled [and] Henry's $5,000 with 3.34 m. What's up!?"

    # Expected tokens, grouped by the part of the sentence they come from.
    tokens_out = (
        ["Hello", "(", "bracket", ")", "and"]
        + ["side", "@-@", "scrolled"]
        + ["[", "and", "]"]
        + ["Henry", "'s"]
        + ["$", "5", "@,@", "000", "with"]
        + ["3", "@.@", "34", "m", "."]
        + ["What", "'s", "up", "!", "?"]
    )

    self.assertListEqual(tokenizer.tokenize(text_in), tokens_out)

    self.assertEqual(tokenizer.convert_tokens_to_string(tokens_out), text_in)
|
||||
|
||||
def test_move_added_token(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
original_len = len(tokenizer)
|
||||
|
||||
Reference in New Issue
Block a user