Transformer-XL: Improved tokenization with sacremoses (#6322)
* Improved tokenization with sacremoses * The TransfoXLTokenizer is now using sacremoses for tokenization * Added tokenization of comma-separated and floating point numbers. * Removed prepare_for_tokenization() from tokenization_transfo_xl.py because punctuation is handled by sacremoses * Added corresponding tests * Removed test comparing TransfoXLTokenizer and TransfoXLTokenizerFast * Added deprecation warning to TransfoXLTokenizerFast * isort change Co-authored-by: Teven <teven.lescao@gmail.com> Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -22,11 +22,13 @@ import glob
|
|||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import re
|
import re
|
||||||
|
import warnings
|
||||||
from collections import Counter, OrderedDict
|
from collections import Counter, OrderedDict
|
||||||
from typing import Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
import sacremoses as sm
|
||||||
from tokenizers import Tokenizer
|
from tokenizers import Tokenizer
|
||||||
from tokenizers.implementations import BaseTokenizer
|
from tokenizers.implementations import BaseTokenizer
|
||||||
from tokenizers.models import WordLevel
|
from tokenizers.models import WordLevel
|
||||||
@@ -70,6 +72,47 @@ PRETRAINED_CORPUS_ARCHIVE_MAP = {
|
|||||||
}
|
}
|
||||||
CORPUS_NAME = "corpus.bin"
|
CORPUS_NAME = "corpus.bin"
|
||||||
|
|
||||||
|
# (pattern, replacement) pair: a ',' or '.' that sits between two digits is
# wrapped as ' @,@ ' / ' @.@ ' so that it becomes a token of its own.
MATCH_NUMBERS = r"(?<=\d)[,.](?=\d)", r" @\g<0>@ "
# (pattern, replacement) pairs that undo the wrapping applied by MATCH_NUMBERS.
DETOKENIZE_NUMBERS = [(r" @\,@ ", r","), (r" @\.@ ", r".")]


def tokenize_numbers(text_array: List[str]) -> List[str]:
    """
    Splits large comma-separated numbers and floating point values.
    This is done by replacing commas with ' @,@ ' and dots with ' @.@ '
    whenever they appear between two digits, then splitting on whitespace.

    Args:
        text_array: An already tokenized text as list

    Returns:
        A list of strings with tokenized numbers

    Example::
        >>> tokenize_numbers(["$", "5,000", "1.73", "m"])
        ["$", "5", "@,@", "000", "1", "@.@", "73", "m"]
    """
    # Unpack once; the pattern and replacement do not change per token.
    pattern, replacement = MATCH_NUMBERS
    tokenized = []
    for token in text_array:
        # A single input token may expand into several output tokens.
        tokenized.extend(re.sub(pattern, replacement, token).split())
    return tokenized


def detokenize_numbers(text: str) -> str:
    """
    Inverts the operation of `tokenize_numbers`.
    This is replacing ' @,@ ' and ' @.@ ' by ',' and '.'.

    Args:
        text: A string where the number should be detokenized

    Returns:
        A detokenized string

    Example::
        >>> detokenize_numbers("$ 5 @,@ 000 1 @.@ 73 m")
        "$ 5,000 1.73 m"
    """
    for reg, sub in DETOKENIZE_NUMBERS:
        text = re.sub(reg, sub, text)
    return text
|
||||||
|
|
||||||
|
|
||||||
class TransfoXLTokenizer(PreTrainedTokenizer):
|
class TransfoXLTokenizer(PreTrainedTokenizer):
|
||||||
"""
|
"""
|
||||||
@@ -97,6 +140,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
|||||||
unk_token="<unk>",
|
unk_token="<unk>",
|
||||||
eos_token="<eos>",
|
eos_token="<eos>",
|
||||||
additional_special_tokens=["<formula>"],
|
additional_special_tokens=["<formula>"],
|
||||||
|
language="en",
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -118,6 +162,10 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
|||||||
self.punctuation_symbols = '!"#$%&()*+,-./\\:;<=>?@[\\]^_`{|}~'
|
self.punctuation_symbols = '!"#$%&()*+,-./\\:;<=>?@[\\]^_`{|}~'
|
||||||
self.punction_without_space_before_pattern = re.compile(r"[^\s][{}]".format(self.punctuation_symbols))
|
self.punction_without_space_before_pattern = re.compile(r"[^\s][{}]".format(self.punctuation_symbols))
|
||||||
self.punctuation_with_space_around_pattern = self._compile_space_around_punctuation_pattern()
|
self.punctuation_with_space_around_pattern = self._compile_space_around_punctuation_pattern()
|
||||||
|
self.language = language
|
||||||
|
self.moses_punct_normalizer = sm.MosesPunctNormalizer(language)
|
||||||
|
self.moses_tokenizer = sm.MosesTokenizer(language)
|
||||||
|
self.moses_detokenizer = sm.MosesDetokenizer(language)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if pretrained_vocab_file is not None:
|
if pretrained_vocab_file is not None:
|
||||||
@@ -300,6 +348,34 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
|||||||
del self.added_tokens_decoder[old_index]
|
del self.added_tokens_decoder[old_index]
|
||||||
del self.added_tokens_encoder[token]
|
del self.added_tokens_encoder[token]
|
||||||
|
|
||||||
|
def moses_punct_norm(self, text):
    """Return ``text`` after running it through ``self.moses_punct_normalizer.normalize``."""
    normalized = self.moses_punct_normalizer.normalize(text)
    return normalized
|
||||||
|
|
||||||
|
def moses_tokenize(self, text):
    """
    Tokenize ``text`` with ``self.moses_tokenizer``.

    The tokenizer is called with aggressive dash splitting enabled, escaping
    disabled, ``return_str=False``, and ``self.never_split`` passed as the
    protected patterns.
    """
    options = dict(
        aggressive_dash_splits=True,
        return_str=False,
        escape=False,
        protected_patterns=self.never_split,
    )
    return self.moses_tokenizer.tokenize(text, **options)
|
||||||
|
|
||||||
|
def moses_pipeline(self, text: str) -> List[str]:
    """
    Runs the basic tokenization pipeline in three steps: punctuation
    normalization (:meth:`moses_punct_norm`), Moses tokenization
    (:meth:`moses_tokenize`, which uses `aggressive_dash_splits=True`), and
    splitting of comma-separated and floating point numbers
    (:func:`tokenize_numbers`).
    E.g. "23,000 people are 1.80m tall" -> "23 @,@ 000 people are 1 @.@ 80m tall".

    Args:
        text: Text to be tokenized

    Returns:
        A list of tokenized strings

    Example::
        >>> tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
        >>> tokenizer.moses_pipeline("23,000 people are 1.80 m tall")
        ['23', '@,@', '000', 'people', 'are', '1', '@.@', '80', 'm', 'tall']
    """
    normalized = self.moses_punct_norm(text)
    moses_tokens = self.moses_tokenize(normalized)
    return tokenize_numbers(moses_tokens)
|
||||||
|
|
||||||
def _convert_id_to_token(self, idx):
|
def _convert_id_to_token(self, idx):
|
||||||
"""Converts an id in a token (BPE) using the vocab."""
|
"""Converts an id in a token (BPE) using the vocab."""
|
||||||
assert 0 <= idx < len(self), "Index {} out of vocabulary range".format(idx)
|
assert 0 <= idx < len(self), "Index {} out of vocabulary range".format(idx)
|
||||||
@@ -323,9 +399,12 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
|||||||
raise ValueError("Token not in vocabulary and no <unk> token in vocabulary for replacement")
|
raise ValueError("Token not in vocabulary and no <unk> token in vocabulary for replacement")
|
||||||
|
|
||||||
def convert_tokens_to_string(self, tokens):
|
def convert_tokens_to_string(self, tokens):
|
||||||
""" Converts a sequence of tokens (string) in a single string. """
|
"""
|
||||||
out_string = " ".join(tokens).strip()
|
Converts a sequence of tokens (string) in a single string.
|
||||||
return out_string
|
Additionally, the split numbers are converted back into their original form.
|
||||||
|
"""
|
||||||
|
out_string = self.moses_detokenizer.detokenize(tokens)
|
||||||
|
return detokenize_numbers(out_string).strip()
|
||||||
|
|
||||||
def convert_to_tensor(self, symbols):
|
def convert_to_tensor(self, symbols):
|
||||||
return torch.LongTensor(self.convert_tokens_to_ids(symbols))
|
return torch.LongTensor(self.convert_tokens_to_ids(symbols))
|
||||||
@@ -347,7 +426,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
|||||||
if self.delimiter == "":
|
if self.delimiter == "":
|
||||||
symbols = line
|
symbols = line
|
||||||
else:
|
else:
|
||||||
symbols = line.split(self.delimiter)
|
symbols = self.moses_pipeline(line)
|
||||||
|
|
||||||
if add_double_eos: # lm1b
|
if add_double_eos: # lm1b
|
||||||
return ["<S>"] + symbols + ["<S>"]
|
return ["<S>"] + symbols + ["<S>"]
|
||||||
@@ -356,19 +435,6 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
|||||||
else:
|
else:
|
||||||
return symbols
|
return symbols
|
||||||
|
|
||||||
def prepare_for_tokenization(self, text, is_pretokenized=False, **kwargs):
|
|
||||||
# add spaces before punctuation symbols as should be done in transfo-xl
|
|
||||||
add_space_before_punct_symbol = kwargs.pop("add_space_before_punct_symbol", False)
|
|
||||||
if add_space_before_punct_symbol:
|
|
||||||
text = self.punctuation_with_space_around_pattern.sub(r" ", text)
|
|
||||||
elif self.punction_without_space_before_pattern.search(text):
|
|
||||||
# searches until the first occurence of a punctuation symbol without surrounding spaces
|
|
||||||
logger.warning(
|
|
||||||
"You might want to consider setting `add_space_before_punct_symbol=True` as an argument to the `tokenizer.encode()` to avoid tokenizing words with punctuation symbols to the `<unk>` token"
|
|
||||||
)
|
|
||||||
|
|
||||||
return (text, kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class _TransfoXLDelimiterLookupTokenizer(BaseTokenizer):
|
class _TransfoXLDelimiterLookupTokenizer(BaseTokenizer):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -484,6 +550,11 @@ class TransfoXLTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
warnings.warn(
|
||||||
|
"The class `TransfoXLTokenizerFast` is deprecated and will be removed in a future version. Please use `TransfoXLTokenizer` with it's enhanced tokenization instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
|
||||||
def save_pretrained(self, save_directory):
|
def save_pretrained(self, save_directory):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Please note you will not be able to load the vocabulary in"
|
"Please note you will not be able to load the vocabulary in"
|
||||||
|
|||||||
@@ -12,14 +12,12 @@ from transformers import (
|
|||||||
OpenAIGPTTokenizer,
|
OpenAIGPTTokenizer,
|
||||||
PreTrainedTokenizer,
|
PreTrainedTokenizer,
|
||||||
RobertaTokenizer,
|
RobertaTokenizer,
|
||||||
TransfoXLTokenizer,
|
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import get_tests_dir, require_torch
|
from transformers.testing_utils import get_tests_dir, require_torch
|
||||||
from transformers.tokenization_distilbert import DistilBertTokenizerFast
|
from transformers.tokenization_distilbert import DistilBertTokenizerFast
|
||||||
from transformers.tokenization_openai import OpenAIGPTTokenizerFast
|
from transformers.tokenization_openai import OpenAIGPTTokenizerFast
|
||||||
from transformers.tokenization_roberta import RobertaTokenizerFast
|
from transformers.tokenization_roberta import RobertaTokenizerFast
|
||||||
from transformers.tokenization_transfo_xl import TransfoXLTokenizerFast
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -895,17 +893,3 @@ class NoPaddingTokenFastTokenizerMatchingTest(CommonFastTokenizerTest):
|
|||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
padding="max_length",
|
padding="max_length",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TransfoXLFastTokenizerTest(NoPaddingTokenFastTokenizerMatchingTest):
|
|
||||||
TOKENIZERS_CLASSES = frozenset(
|
|
||||||
[Tokenizer("TransfoXL", TransfoXLTokenizerFast, TransfoXLTokenizer, "pretrained_vocab_file", None, None)]
|
|
||||||
)
|
|
||||||
|
|
||||||
@require_torch
|
|
||||||
def test_all_tokenizers(self):
|
|
||||||
super().test_all_tokenizers()
|
|
||||||
|
|
||||||
@require_torch
|
|
||||||
def test_pretokenized_tokenizers(self):
|
|
||||||
super().test_pretokenized_tokenizers()
|
|
||||||
|
|||||||
@@ -83,6 +83,44 @@ class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
tokenizer.tokenize(" \tHeLLo ! how \n Are yoU ? "), ["HeLLo", "!", "how", "Are", "yoU", "?"]
|
tokenizer.tokenize(" \tHeLLo ! how \n Are yoU ? "), ["HeLLo", "!", "how", "Are", "yoU", "?"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_full_tokenizer_moses_numbers(self):
    """Tokenize a sentence with brackets, a dash, possessives and numbers, then convert the tokens back."""
    tokenizer = TransfoXLTokenizer(lower_case=False)
    sentence = "Hello (bracket) and side-scrolled [and] Henry's $5,000 with 3.34 m. What's up!?"
    expected_tokens = [
        "Hello", "(", "bracket", ")", "and",
        "side", "@-@", "scrolled",
        "[", "and", "]",
        "Henry", "'s",
        "$", "5", "@,@", "000",
        "with", "3", "@.@", "34", "m", ".",
        "What", "'s", "up", "!", "?",
    ]

    # Forward direction: raw text -> tokens.
    self.assertListEqual(tokenizer.tokenize(sentence), expected_tokens)

    # Reverse direction: tokens -> the original text.
    self.assertEqual(tokenizer.convert_tokens_to_string(expected_tokens), sentence)
|
||||||
|
|
||||||
def test_move_added_token(self):
|
def test_move_added_token(self):
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
original_len = len(tokenizer)
|
original_len = len(tokenizer)
|
||||||
|
|||||||
Reference in New Issue
Block a user