From cb276b41deaa97088e6d56fdca9ceb17db16ca19 Mon Sep 17 00:00:00 2001 From: RafaelWO <38643099+RafaelWO@users.noreply.github.com> Date: Fri, 28 Aug 2020 15:56:17 +0200 Subject: [PATCH] Transformer-XL: Improved tokenization with sacremoses (#6322) * Improved tokenization with sacremoses * The TransfoXLTokenizer is now using sacremoses for tokenization * Added tokenization of comma-separated and floating point numbers. * Removed prepare_for_tokenization() from tokenization_transfo_xl.py because punctuation is handled by sacremoses * Added corresponding tests * Removed test comparing TransfoXLTokenizer and TransfoXLTokenizerFast * Added deprecation warning to TransfoXLTokenizerFast * isort change Co-authored-by: Teven Co-authored-by: Lysandre Debut --- src/transformers/tokenization_transfo_xl.py | 107 ++++++++++++++++---- tests/test_tokenization_fast.py | 16 --- tests/test_tokenization_transfo_xl.py | 38 +++++++ 3 files changed, 127 insertions(+), 34 deletions(-) diff --git a/src/transformers/tokenization_transfo_xl.py b/src/transformers/tokenization_transfo_xl.py index 3f9035a5e0..8a677a713a 100644 --- a/src/transformers/tokenization_transfo_xl.py +++ b/src/transformers/tokenization_transfo_xl.py @@ -22,11 +22,13 @@ import glob import os import pickle import re +import warnings from collections import Counter, OrderedDict -from typing import Optional +from typing import List, Optional import numpy as np +import sacremoses as sm from tokenizers import Tokenizer from tokenizers.implementations import BaseTokenizer from tokenizers.models import WordLevel @@ -70,6 +72,47 @@ PRETRAINED_CORPUS_ARCHIVE_MAP = { } CORPUS_NAME = "corpus.bin" +MATCH_NUMBERS = r"(?<=\d)[,.](?=\d)", r" @\g<0>@ " +DETOKENIZE_NUMBERS = [(r" @\,@ ", r","), (r" @\.@ ", r".")] + + +def tokenize_numbers(text_array: List[str]) -> List[str]: + """ + Splits large comma-separated numbers and floating point values. + This is done by replacing commas with ' @,@ ' and dots with ' @.@ '. 
+ Args: + text_array: An already tokenized text as list + Returns: + A list of strings with tokenized numbers + Example:: + >>> tokenize_numbers(["$", "5,000", "1.73", "m"]) + ["$", "5", "@,@", "000", "1", "@.@", "73", "m"] + """ + tokenized = [] + for i in range(len(text_array)): + reg, sub = MATCH_NUMBERS + replaced = re.sub(reg, sub, text_array[i]).split() + tokenized.extend(replaced) + + return tokenized + + +def detokenize_numbers(text: str) -> str: + """ + Inverts the operation of `tokenize_numbers`. + This is replacing ' @,@ ' and ' @.@' by ',' and '.'. + Args: + text: A string where the number should be detokenized + Returns: + A detokenized string + Example:: + >>> detokenize_numbers("$ 5 @,@ 000 1 @.@ 73 m") + "$ 5,000 1.73 m" + """ + for reg, sub in DETOKENIZE_NUMBERS: + text = re.sub(reg, sub, text) + return text + class TransfoXLTokenizer(PreTrainedTokenizer): """ @@ -97,6 +140,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer): unk_token="", eos_token="", additional_special_tokens=[""], + language="en", **kwargs ): super().__init__( @@ -118,6 +162,10 @@ class TransfoXLTokenizer(PreTrainedTokenizer): self.punctuation_symbols = '!"#$%&()*+,-./\\:;<=>?@[\\]^_`{|}~' self.punction_without_space_before_pattern = re.compile(r"[^\s][{}]".format(self.punctuation_symbols)) self.punctuation_with_space_around_pattern = self._compile_space_around_punctuation_pattern() + self.language = language + self.moses_punct_normalizer = sm.MosesPunctNormalizer(language) + self.moses_tokenizer = sm.MosesTokenizer(language) + self.moses_detokenizer = sm.MosesDetokenizer(language) try: if pretrained_vocab_file is not None: @@ -300,6 +348,34 @@ class TransfoXLTokenizer(PreTrainedTokenizer): del self.added_tokens_decoder[old_index] del self.added_tokens_encoder[token] + def moses_punct_norm(self, text): + return self.moses_punct_normalizer.normalize(text) + + def moses_tokenize(self, text): + return self.moses_tokenizer.tokenize( + text, aggressive_dash_splits=True, 
return_str=False, escape=False, protected_patterns=self.never_split + ) + + def moses_pipeline(self, text: str) -> List[str]: + """ + Does basic tokenization using :class:`sacremoses.MosesPunctNormalizer` and :class:`sacremoses.MosesTokenizer` + with `aggressive_dash_splits=True` (see :func:`sacremoses.tokenize.MosesTokenizer.tokenize`). + Additionally, large comma-separated numbers and floating point values are split. + E.g. "23,000 people are 1.80m tall" -> "23 @,@ 000 people are 1 @.@ 80m tall". + Args: + text: Text to be tokenized + Returns: + A list of tokenized strings + Example:: + >>> tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103") + >>> tokenizer.moses_pipeline("23,000 people are 1.80 m tall") + ['23', '@,@', '000', 'people', 'are', '1', '@.@', '80', 'm', 'tall'] + """ + text = self.moses_punct_norm(text) + text = self.moses_tokenize(text) + text = tokenize_numbers(text) + return text + def _convert_id_to_token(self, idx): """Converts an id in a token (BPE) using the vocab.""" assert 0 <= idx < len(self), "Index {} out of vocabulary range".format(idx) @@ -323,9 +399,12 @@ class TransfoXLTokenizer(PreTrainedTokenizer): raise ValueError("Token not in vocabulary and no token in vocabulary for replacement") def convert_tokens_to_string(self, tokens): - """ Converts a sequence of tokens (string) in a single string. """ - out_string = " ".join(tokens).strip() - return out_string + """ + Converts a sequence of tokens (string) in a single string. + Additionally, the split numbers are converted back into it's original form. 
+ """ + out_string = self.moses_detokenizer.detokenize(tokens) + return detokenize_numbers(out_string).strip() def convert_to_tensor(self, symbols): return torch.LongTensor(self.convert_tokens_to_ids(symbols)) @@ -347,7 +426,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer): if self.delimiter == "": symbols = line else: - symbols = line.split(self.delimiter) + symbols = self.moses_pipeline(line) if add_double_eos: # lm1b return [""] + symbols + [""] @@ -356,19 +435,6 @@ class TransfoXLTokenizer(PreTrainedTokenizer): else: return symbols - def prepare_for_tokenization(self, text, is_pretokenized=False, **kwargs): - # add spaces before punctuation symbols as should be done in transfo-xl - add_space_before_punct_symbol = kwargs.pop("add_space_before_punct_symbol", False) - if add_space_before_punct_symbol: - text = self.punctuation_with_space_around_pattern.sub(r" ", text) - elif self.punction_without_space_before_pattern.search(text): - # searches until the first occurence of a punctuation symbol without surrounding spaces - logger.warning( - "You might want to consider setting `add_space_before_punct_symbol=True` as an argument to the `tokenizer.encode()` to avoid tokenizing words with punctuation symbols to the `` token" - ) - - return (text, kwargs) - class _TransfoXLDelimiterLookupTokenizer(BaseTokenizer): def __init__( @@ -484,6 +550,11 @@ class TransfoXLTokenizerFast(PreTrainedTokenizerFast): **kwargs, ) + warnings.warn( + "The class `TransfoXLTokenizerFast` is deprecated and will be removed in a future version. 
Please use `TransfoXLTokenizer` with it's enhanced tokenization instead.", + FutureWarning, + ) + def save_pretrained(self, save_directory): logger.warning( "Please note you will not be able to load the vocabulary in" diff --git a/tests/test_tokenization_fast.py b/tests/test_tokenization_fast.py index a0a9d49646..ba466c45d5 100644 --- a/tests/test_tokenization_fast.py +++ b/tests/test_tokenization_fast.py @@ -12,14 +12,12 @@ from transformers import ( OpenAIGPTTokenizer, PreTrainedTokenizer, RobertaTokenizer, - TransfoXLTokenizer, is_torch_available, ) from transformers.testing_utils import get_tests_dir, require_torch from transformers.tokenization_distilbert import DistilBertTokenizerFast from transformers.tokenization_openai import OpenAIGPTTokenizerFast from transformers.tokenization_roberta import RobertaTokenizerFast -from transformers.tokenization_transfo_xl import TransfoXLTokenizerFast logger = logging.getLogger(__name__) @@ -895,17 +893,3 @@ class NoPaddingTokenFastTokenizerMatchingTest(CommonFastTokenizerTest): max_length=max_length, padding="max_length", ) - - -class TransfoXLFastTokenizerTest(NoPaddingTokenFastTokenizerMatchingTest): - TOKENIZERS_CLASSES = frozenset( - [Tokenizer("TransfoXL", TransfoXLTokenizerFast, TransfoXLTokenizer, "pretrained_vocab_file", None, None)] - ) - - @require_torch - def test_all_tokenizers(self): - super().test_all_tokenizers() - - @require_torch - def test_pretokenized_tokenizers(self): - super().test_pretokenized_tokenizers() diff --git a/tests/test_tokenization_transfo_xl.py b/tests/test_tokenization_transfo_xl.py index 7f4dca4725..1688a9f3a6 100644 --- a/tests/test_tokenization_transfo_xl.py +++ b/tests/test_tokenization_transfo_xl.py @@ -83,6 +83,44 @@ class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase): tokenizer.tokenize(" \tHeLLo ! how \n Are yoU ? 
"), ["HeLLo", "!", "how", "Are", "yoU", "?"] ) + def test_full_tokenizer_moses_numbers(self): + tokenizer = TransfoXLTokenizer(lower_case=False) + text_in = "Hello (bracket) and side-scrolled [and] Henry's $5,000 with 3.34 m. What's up!?" + tokens_out = [ + "Hello", + "(", + "bracket", + ")", + "and", + "side", + "@-@", + "scrolled", + "[", + "and", + "]", + "Henry", + "'s", + "$", + "5", + "@,@", + "000", + "with", + "3", + "@.@", + "34", + "m", + ".", + "What", + "'s", + "up", + "!", + "?", + ] + + self.assertListEqual(tokenizer.tokenize(text_in), tokens_out) + + self.assertEqual(tokenizer.convert_tokens_to_string(tokens_out), text_in) + def test_move_added_token(self): tokenizer = self.get_tokenizer() original_len = len(tokenizer)