From 65d74c496512d279cbe1c10a201f1745a1afa98a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 24 Feb 2020 21:11:10 +0100 Subject: [PATCH] Add preprocessing step for transfo-xl tokenization to avoid tokenizing words followed by punction to (#2987) * add preprocessing to add space before punctuation for transfo_xl * improve warning messages * make style * compile regex at instantination of tokenizer object --- examples/run_generation.py | 6 ++++-- src/transformers/tokenization_transfo_xl.py | 22 +++++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/examples/run_generation.py b/examples/run_generation.py index 5358d650b4..0652567b6b 100644 --- a/examples/run_generation.py +++ b/examples/run_generation.py @@ -59,7 +59,7 @@ MODEL_CLASSES = { # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia # in https://github.com/rusiaaman/XLNet-gen#methodology # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e -PADDING_TEXT = """ In 1991, the remains of Russian Tsar Nicholas II and his family +PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family (except for Alexei and Maria) are discovered. The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the remainder of the story. 1883 Western Siberia, @@ -214,7 +214,9 @@ def main(): if requires_preprocessing: prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type) preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text) - encoded_prompt = tokenizer.encode(preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt") + encoded_prompt = tokenizer.encode( + preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", add_space_before_punct_symbol=True + ) else: encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt") encoded_prompt = encoded_prompt.to(args.device) diff --git a/src/transformers/tokenization_transfo_xl.py b/src/transformers/tokenization_transfo_xl.py index f3f7ff3f31..6ca9cb46ce 100644 --- a/src/transformers/tokenization_transfo_xl.py +++ b/src/transformers/tokenization_transfo_xl.py @@ -22,6 +22,7 @@ import glob import logging import os import pickle +import re from collections import Counter, OrderedDict from typing import List, Optional, Tuple, Union @@ -114,6 +115,9 @@ class TransfoXLTokenizer(PreTrainedTokenizer): self.delimiter = delimiter self.vocab_file = vocab_file self.never_split = never_split + self.punctuation_symbols = '!"#$%&()*+,-./\:;<=>?@[\\]^_`{|}~' # noqa: W605 + self.punction_without_space_before_pattern = re.compile(r"[^\s][{}]".format(self.punctuation_symbols)) + self.punctuation_with_space_around_pattern = self._compile_space_around_punctuation_pattern() if pretrained_vocab_file is not None: # Hack because, honestly this tokenizer was not made to be used @@ -126,6 +130,11 @@ class TransfoXLTokenizer(PreTrainedTokenizer): if vocab_file is not None: self.build_vocab() + def _compile_space_around_punctuation_pattern(self): + look_ahead_for_special_token = "(?=[{}])".format(self.punctuation_symbols) + look_ahead_to_match_all_except_space = "(?=[^\s])" # noqa: W605 + return re.compile(r"" + look_ahead_for_special_token + look_ahead_to_match_all_except_space) + def count_file(self, path, verbose=False, add_eos=False): if verbose: logger.info("counting file {} ...".format(path)) @@ -295,6 +304,19 @@ class TransfoXLTokenizer(PreTrainedTokenizer): else: return symbols + def prepare_for_tokenization(self, text, **kwargs): + # add spaces before punctuation symbols as should be done in transfo-xl + + if "add_space_before_punct_symbol" in kwargs and kwargs["add_space_before_punct_symbol"]: + text = self.punctuation_with_space_around_pattern.sub(r" ", text) + elif self.punction_without_space_before_pattern.search(text): + # searches until the first occurence of a punctuation symbol without surrounding spaces + logger.warning( + "You might want to consider setting `add_space_before_punct_symbol=True` as an argument to the `tokenizer.encode()` to avoid tokenizing words with punctuation symbols to the `` token" + ) + + return text + class _TransfoXLDelimiterLookupTokenizer(BaseTokenizer): def __init__(